chiark / gitweb /
Refactor Series.new_patch
[stgit] / stgit / stack.py
1 """Basic quilt-like functionality
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, re
22 from email.Utils import formatdate
23
24 from stgit.utils import *
25 from stgit.out import *
26 from stgit import git, basedir, templates
27 from stgit.config import config
28 from shutil import copyfile
29
30
class StackException(Exception):
    """Error raised by the StGIT stack (series and patch) operations."""
34
class FilterUntil:
    """Stateful line predicate used to strip StGIT helper lines.

    Calling the instance with a line returns True when the line should be
    kept.  Lines starting with ``prefix`` are rejected, and as soon as a
    line satisfies ``until_test`` that line and every later one are
    rejected as well.
    """
    def __init__(self):
        # Becomes False for good once the cut-off line has been seen.
        self.should_print = True

    def __call__(self, x, until_test, prefix):
        if until_test(x):
            self.should_print = False
        if not self.should_print:
            return False
        return not x.startswith(prefix)
44
45 #
46 # Functions
47 #
# Lines starting with this marker are editor hints and get dropped.
__comment_prefix = 'STG:'
# A line consisting solely of this marker starts the diff preview; that
# line and everything after it get dropped.
__patch_prefix = 'STG_PATCH:'

def __clean_comments(f):
    """Strip the StGIT helper lines from the message file f, in place.

    Drops every line starting with the comment prefix, everything from
    the patch-prefix line onwards and any trailing blank lines, then
    rewrites the file with the remaining lines.
    """
    f.seek(0)

    marker = __patch_prefix + '\n'
    def reached_patch(text):
        return text == marker

    keep = FilterUntil()
    kept = [text for text in f.readlines()
            if keep(text, reached_patch, __comment_prefix)]

    # drop the blank lines left at the end of the message
    while kept and kept[-1] == '\n':
        kept.pop()

    f.seek(0)
    f.truncate()
    f.writelines(kept)
69
70 # TODO: move this out of the stgit.stack module, it is really for
71 # higher level commands to handle the user interaction
def edit_file(series, line, comment, show_patch = True):
    """Let the user edit a patch description in an external editor.

    A temporary message file is seeded with ``line`` (or, when ``line`` is
    empty, with the 'patchdescr.tmpl' template), a few "STG:" hint lines
    and, when ``show_patch`` is set, a diff preview.  After the editor
    returns, the hint lines, the preview and trailing blank lines are
    stripped and the remaining text is returned.

    series     -- the Series being edited (used for the diff preview)
    line       -- initial description text, may be empty/None
    comment    -- hint shown to the user on an "STG:" line
    show_patch -- append a diff preview below an "STG_PATCH:" marker
    """
    fname = '.stgitmsg.txt'
    tmpl = templates.get_template('patchdescr.tmpl')

    f = open(fname, 'w+')
    try:
        if line:
            f.write('%s\n' % line)
        elif tmpl:
            f.write(tmpl)
            # Keep the hint lines below on lines of their own: appended to
            # the last template line they would not be stripped afterwards.
            if not tmpl.endswith('\n'):
                f.write('\n')
        else:
            f.write('\n')
        f.write('%s %s\n' % (__comment_prefix, comment))
        f.write('%s Lines prefixed with "%s" will be automatically removed.\n'
                % (__comment_prefix, __comment_prefix))
        f.write('%s Trailing empty lines will be automatically removed.\n'
                % __comment_prefix)

        if show_patch:
            f.write('%s\n' % __patch_prefix)
            crt = series.get_current_patch()
            if crt:
                rev1 = crt.get_bottom()
            else:
                # No patch applied (e.g. creating the first patch): there
                # is no current bottom, so diff against the stack base.
                rev1 = series.get_base()
            f.write(git.diff(rev1 = rev1))

        # Vim modeline must be near the end.
        f.write('%s vi: set textwidth=75 filetype=diff nobackup:\n'
                % __comment_prefix)
    finally:
        f.close()

    call_editor(fname)

    f = open(fname, 'r+')
    try:
        __clean_comments(f)
        f.seek(0)
        result = f.read()
    finally:
        f.close()
    os.remove(fname)

    return result
112
113 #
114 # Classes
115 #
116
class StgitObject:
    """Base for objects whose properties are stored as files in a directory.

    Every property ("field") lives in its own file, named after the field,
    inside the directory registered with _set_dir().
    """
    def _set_dir(self, dir):
        self.__dir = dir

    def _dir(self):
        return self.__dir

    def __field_path(self, name):
        # Path of the file backing the given field.
        return os.path.join(self.__dir, name)

    def create_empty_field(self, name):
        create_empty_file(self.__field_path(name))

    def _get_field(self, name, multiline = False):
        """Return the field's content, or None if it is missing or empty."""
        path = self.__field_path(name)
        if not os.path.isfile(path):
            return None
        content = read_string(path, multiline)
        if content == '':
            return None
        return content

    def _set_field(self, name, value, multiline = False):
        """Store value in the field; an empty value removes the field file."""
        path = self.__field_path(name)
        if not value or value == '':
            if os.path.isfile(path):
                os.remove(path)
        else:
            write_string(path, value, multiline)
145
146     
class Patch(StgitObject):
    """Basic patch implementation.

    The patch metadata (bottom/top commit ids, description, author and
    committer details, log) is stored as field files in
    <series_dir>/<name>.  The top commit and the patch log are also
    recorded as the git refs <refs_base>/<name> and <refs_base>/<name>.log.
    """
    def __init_refs(self):
        # (Re)compute the git ref names from the current patch name.
        self.__top_ref = self.__refs_base + '/' + self.__name
        self.__log_ref = self.__top_ref + '.log'

    def __init__(self, name, series_dir, refs_base):
        self.__series_dir = series_dir
        self.__name = name
        self._set_dir(os.path.join(self.__series_dir, self.__name))
        self.__refs_base = refs_base
        self.__init_refs()

    def create(self):
        """Create the patch directory with empty bottom/top fields."""
        os.mkdir(self._dir())
        self.create_empty_field('bottom')
        self.create_empty_field('top')

    def delete(self):
        """Remove the patch directory and its git refs."""
        for f in os.listdir(self._dir()):
            os.remove(os.path.join(self._dir(), f))
        os.rmdir(self._dir())
        git.delete_ref(self.__top_ref)
        # the log ref only exists once the patch has been logged
        if git.ref_exists(self.__log_ref):
            git.delete_ref(self.__log_ref)

    def get_name(self):
        return self.__name

    def rename(self, newname):
        """Rename the patch: its git refs and its metadata directory."""
        olddir = self._dir()
        old_top_ref = self.__top_ref
        old_log_ref = self.__log_ref
        self.__name = newname
        self._set_dir(os.path.join(self.__series_dir, self.__name))
        self.__init_refs()

        git.rename_ref(old_top_ref, self.__top_ref)
        if git.ref_exists(old_log_ref):
            git.rename_ref(old_log_ref, self.__log_ref)
        os.rename(olddir, self._dir())

    def __update_top_ref(self, ref):
        git.set_ref(self.__top_ref, ref)

    def __update_log_ref(self, ref):
        git.set_ref(self.__log_ref, ref)

    def update_top_ref(self):
        """Point the top ref at the stored top commit, if there is one."""
        top = self.get_top()
        if top:
            self.__update_top_ref(top)

    def get_old_bottom(self):
        return self._get_field('bottom.old')

    def get_bottom(self):
        return self._get_field('bottom')

    def set_bottom(self, value, backup = False):
        """Set the bottom commit; with backup, keep the previous value in
        'bottom.old' first."""
        if backup:
            curr = self._get_field('bottom')
            self._set_field('bottom.old', curr)
        self._set_field('bottom', value)

    def get_old_top(self):
        return self._get_field('top.old')

    def get_top(self):
        return self._get_field('top')

    def set_top(self, value, backup = False):
        """Set the top commit (and the top ref); with backup, keep the
        previous value in 'top.old' first."""
        if backup:
            curr = self._get_field('top')
            self._set_field('top.old', curr)
        self._set_field('top', value)
        self.__update_top_ref(value)

    def restore_old_boundaries(self):
        """Restore bottom/top from their '.old' backups.

        Returns True if both backups existed and were restored, False
        otherwise (in which case nothing is changed)."""
        bottom = self._get_field('bottom.old')
        top = self._get_field('top.old')

        if top and bottom:
            self._set_field('bottom', bottom)
            self._set_field('top', top)
            self.__update_top_ref(top)
            return True
        else:
            return False

    def get_description(self):
        return self._get_field('description', True)

    def set_description(self, line):
        self._set_field('description', line, True)

    def get_authname(self):
        return self._get_field('authname')

    def set_authname(self, name):
        self._set_field('authname', name or git.author().name)

    def get_authemail(self):
        return self._get_field('authemail')

    def set_authemail(self, email):
        self._set_field('authemail', email or git.author().email)

    def get_authdate(self):
        """Return the author date.

        A date stored in git's raw "<seconds> <+/-hhmm>" form is turned
        into an RFC 2822 string showing the wall-clock time in that zone;
        any other stored value is returned unchanged."""
        date = self._get_field('authdate')
        if not date:
            return date

        m = re.match(r'([0-9]+)\s+([+-])([0-9]+)', date)
        if m:
            # Unix time (seconds since the epoch, UTC) + time zone.
            # formatdate() renders UTC with a "-0000" suffix; shift the time
            # by the zone offset before replacing that suffix, otherwise the
            # result would denote a different instant.
            secs, sign, zone = m.groups()
            hours, minutes = divmod(int(zone), 100)
            offset = hours * 3600 + minutes * 60
            if sign == '-':
                offset = -offset
            date = formatdate(int(secs) + offset)[:-5] + date.split()[1]

        return date

    def set_authdate(self, date):
        self._set_field('authdate', date or git.author().date)

    def get_commname(self):
        return self._get_field('commname')

    def set_commname(self, name):
        self._set_field('commname', name or git.committer().name)

    def get_commemail(self):
        return self._get_field('commemail')

    def set_commemail(self, email):
        self._set_field('commemail', email or git.committer().email)

    def get_log(self):
        return self._get_field('log')

    def set_log(self, value, backup = False):
        # 'backup' is accepted for interface symmetry but not used.
        self._set_field('log', value)
        self.__update_log_ref(value)
289
# The current StGIT metadata format version.
FORMAT_VERSION = 2

class PatchSet(StgitObject):
    """Common behaviour of a named StGIT branch (a set of patches).

    Subclasses provide format_version_key(), get_current_patch() and
    get_base(), which the methods below rely on.
    """
    def __init__(self, name = None):
        try:
            if not name:
                name = git.get_head_file()
            self.set_name(name)
            self.__base_dir = basedir.get()
        except git.GitException:
            raise StackException('GIT tree not initialised: %s'
                                 % sys.exc_info()[1])

        self._set_dir(os.path.join(self.__base_dir, 'patches',
                                   self.get_name()))

    def get_name(self):
        return self.__name

    def set_name(self, name):
        self.__name = name

    def _basedir(self):
        return self.__base_dir

    def get_head(self):
        """Return the branch head: the top of the topmost applied patch,
        or the stack base when no patch is applied."""
        crt = self.get_current_patch()
        if not crt:
            return self.get_base()
        return crt.get_top()

    def __protect_file(self):
        # Marker file whose presence means the branch is protected.
        return os.path.join(self._dir(), 'protected')

    def get_protected(self):
        return os.path.isfile(self.__protect_file())

    def protect(self):
        path = self.__protect_file()
        if not os.path.isfile(path):
            create_empty_file(path)

    def unprotect(self):
        path = self.__protect_file()
        if os.path.isfile(path):
            os.remove(path)

    def __branch_descr(self):
        return 'branch.%s.description' % self.get_name()

    def get_description(self):
        return config.get(self.__branch_descr()) or ''

    def set_description(self, line):
        """Store the branch description in git config; an empty line
        removes it."""
        key = self.__branch_descr()
        if not line:
            config.unset(key)
        else:
            config.set(key, line)

    def head_top_equal(self):
        """Return true if HEAD equals the top of the topmost patch (always
        true when no patch is applied)."""
        crt = self.get_current_patch()
        return not crt or git.get_head() == crt.get_top()

    def is_initialised(self):
        """Checks if series is already initialised (has a recorded format
        version in git config)."""
        return bool(config.get(self.format_version_key()))
361
362
def shortlog(patches):
    """Return the git-shortlog summary of the commits in the given patches.

    Each patch contributes the "git-log --pretty=short" output of its
    bottom..top range; the concatenation is fed to git-shortlog.
    """
    logs = [Run('git-log', '--pretty=short',
                p.get_top(), '^%s' % p.get_bottom()).raw_output()
            for p in patches]
    return Run('git-shortlog').raw_input(''.join(logs)).raw_output()
368
369 class Series(PatchSet):
370     """Class including the operations on series
371     """
372     def __init__(self, name = None):
373         """Takes a series name as the parameter.
374         """
375         PatchSet.__init__(self, name)
376
377         # Update the branch to the latest format version if it is
378         # initialized, but don't touch it if it isn't.
379         self.update_to_current_format_version()
380
381         self.__refs_base = 'refs/patches/%s' % self.get_name()
382
383         self.__applied_file = os.path.join(self._dir(), 'applied')
384         self.__unapplied_file = os.path.join(self._dir(), 'unapplied')
385         self.__hidden_file = os.path.join(self._dir(), 'hidden')
386
387         # where this series keeps its patches
388         self.__patch_dir = os.path.join(self._dir(), 'patches')
389
390         # trash directory
391         self.__trash_dir = os.path.join(self._dir(), 'trash')
392
393     def format_version_key(self):
394         return 'branch.%s.stgit.stackformatversion' % self.get_name()
395
396     def update_to_current_format_version(self):
397         """Update a potentially older StGIT directory structure to the
398         latest version. Note: This function should depend as little as
399         possible on external functions that may change during a format
400         version bump, since it must remain able to process older formats."""
401
402         branch_dir = os.path.join(self._basedir(), 'patches', self.get_name())
403         def get_format_version():
404             """Return the integer format version number, or None if the
405             branch doesn't have any StGIT metadata at all, of any version."""
406             fv = config.get(self.format_version_key())
407             ofv = config.get('branch.%s.stgitformatversion' % self.get_name())
408             if fv:
409                 # Great, there's an explicitly recorded format version
410                 # number, which means that the branch is initialized and
411                 # of that exact version.
412                 return int(fv)
413             elif ofv:
414                 # Old name for the version info, upgrade it
415                 config.set(self.format_version_key(), ofv)
416                 config.unset('branch.%s.stgitformatversion' % self.get_name())
417                 return int(ofv)
418             elif os.path.isdir(os.path.join(branch_dir, 'patches')):
419                 # There's a .git/patches/<branch>/patches dirctory, which
420                 # means this is an initialized version 1 branch.
421                 return 1
422             elif os.path.isdir(branch_dir):
423                 # There's a .git/patches/<branch> directory, which means
424                 # this is an initialized version 0 branch.
425                 return 0
426             else:
427                 # The branch doesn't seem to be initialized at all.
428                 return None
429         def set_format_version(v):
430             out.info('Upgraded branch %s to format version %d' % (self.get_name(), v))
431             config.set(self.format_version_key(), '%d' % v)
432         def mkdir(d):
433             if not os.path.isdir(d):
434                 os.makedirs(d)
435         def rm(f):
436             if os.path.exists(f):
437                 os.remove(f)
438         def rm_ref(ref):
439             if git.ref_exists(ref):
440                 git.delete_ref(ref)
441
442         # Update 0 -> 1.
443         if get_format_version() == 0:
444             mkdir(os.path.join(branch_dir, 'trash'))
445             patch_dir = os.path.join(branch_dir, 'patches')
446             mkdir(patch_dir)
447             refs_base = 'refs/patches/%s' % self.get_name()
448             for patch in (file(os.path.join(branch_dir, 'unapplied')).readlines()
449                           + file(os.path.join(branch_dir, 'applied')).readlines()):
450                 patch = patch.strip()
451                 os.rename(os.path.join(branch_dir, patch),
452                           os.path.join(patch_dir, patch))
453                 Patch(patch, patch_dir, refs_base).update_top_ref()
454             set_format_version(1)
455
456         # Update 1 -> 2.
457         if get_format_version() == 1:
458             desc_file = os.path.join(branch_dir, 'description')
459             if os.path.isfile(desc_file):
460                 desc = read_string(desc_file)
461                 if desc:
462                     config.set('branch.%s.description' % self.get_name(), desc)
463                 rm(desc_file)
464             rm(os.path.join(branch_dir, 'current'))
465             rm_ref('refs/bases/%s' % self.get_name())
466             set_format_version(2)
467
468         # Make sure we're at the latest version.
469         if not get_format_version() in [None, FORMAT_VERSION]:
470             raise StackException('Branch %s is at format version %d, expected %d'
471                                  % (self.get_name(), get_format_version(), FORMAT_VERSION))
472
473     def __patch_name_valid(self, name):
474         """Raise an exception if the patch name is not valid.
475         """
476         if not name or re.search('[^\w.-]', name):
477             raise StackException, 'Invalid patch name: "%s"' % name
478
479     def get_patch(self, name):
480         """Return a Patch object for the given name
481         """
482         return Patch(name, self.__patch_dir, self.__refs_base)
483
484     def get_current_patch(self):
485         """Return a Patch object representing the topmost patch, or
486         None if there is no such patch."""
487         crt = self.get_current()
488         if not crt:
489             return None
490         return self.get_patch(crt)
491
492     def get_current(self):
493         """Return the name of the topmost patch, or None if there is
494         no such patch."""
495         try:
496             applied = self.get_applied()
497         except StackException:
498             # No "applied" file: branch is not initialized.
499             return None
500         try:
501             return applied[-1]
502         except IndexError:
503             # No patches applied.
504             return None
505
506     def get_applied(self):
507         if not os.path.isfile(self.__applied_file):
508             raise StackException, 'Branch "%s" not initialised' % self.get_name()
509         return read_strings(self.__applied_file)
510
511     def get_unapplied(self):
512         if not os.path.isfile(self.__unapplied_file):
513             raise StackException, 'Branch "%s" not initialised' % self.get_name()
514         return read_strings(self.__unapplied_file)
515
516     def get_hidden(self):
517         if not os.path.isfile(self.__hidden_file):
518             return []
519         return read_strings(self.__hidden_file)
520
521     def get_base(self):
522         # Return the parent of the bottommost patch, if there is one.
523         if os.path.isfile(self.__applied_file):
524             bottommost = file(self.__applied_file).readline().strip()
525             if bottommost:
526                 return self.get_patch(bottommost).get_bottom()
527         # No bottommost patch, so just return HEAD
528         return git.get_head()
529
530     def get_parent_remote(self):
531         value = config.get('branch.%s.remote' % self.get_name())
532         if value:
533             return value
534         elif 'origin' in git.remotes_list():
535             out.note(('No parent remote declared for stack "%s",'
536                       ' defaulting to "origin".' % self.get_name()),
537                      ('Consider setting "branch.%s.remote" and'
538                       ' "branch.%s.merge" with "git config".'
539                       % (self.get_name(), self.get_name())))
540             return 'origin'
541         else:
542             raise StackException, 'Cannot find a parent remote for "%s"' % self.get_name()
543
544     def __set_parent_remote(self, remote):
545         value = config.set('branch.%s.remote' % self.get_name(), remote)
546
547     def get_parent_branch(self):
548         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
549         if value:
550             return value
551         elif git.rev_parse('heads/origin'):
552             out.note(('No parent branch declared for stack "%s",'
553                       ' defaulting to "heads/origin".' % self.get_name()),
554                      ('Consider setting "branch.%s.stgit.parentbranch"'
555                       ' with "git config".' % self.get_name()))
556             return 'heads/origin'
557         else:
558             raise StackException, 'Cannot find a parent branch for "%s"' % self.get_name()
559
560     def __set_parent_branch(self, name):
561         if config.get('branch.%s.remote' % self.get_name()):
562             # Never set merge if remote is not set to avoid
563             # possibly-erroneous lookups into 'origin'
564             config.set('branch.%s.merge' % self.get_name(), name)
565         config.set('branch.%s.stgit.parentbranch' % self.get_name(), name)
566
567     def set_parent(self, remote, localbranch):
568         if localbranch:
569             if remote:
570                 self.__set_parent_remote(remote)
571             self.__set_parent_branch(localbranch)
572         # We'll enforce this later
573 #         else:
574 #             raise StackException, 'Parent branch (%s) should be specified for %s' % localbranch, self.get_name()
575
576     def __patch_is_current(self, patch):
577         return patch.get_name() == self.get_current()
578
579     def patch_applied(self, name):
580         """Return true if the patch exists in the applied list
581         """
582         return name in self.get_applied()
583
584     def patch_unapplied(self, name):
585         """Return true if the patch exists in the unapplied list
586         """
587         return name in self.get_unapplied()
588
589     def patch_hidden(self, name):
590         """Return true if the patch is hidden.
591         """
592         return name in self.get_hidden()
593
594     def patch_exists(self, name):
595         """Return true if there is a patch with the given name, false
596         otherwise."""
597         return self.patch_applied(name) or self.patch_unapplied(name) \
598                or self.patch_hidden(name)
599
600     def init(self, create_at=False, parent_remote=None, parent_branch=None):
601         """Initialises the stgit series
602         """
603         if self.is_initialised():
604             raise StackException, '%s already initialized' % self.get_name()
605         for d in [self._dir()]:
606             if os.path.exists(d):
607                 raise StackException, '%s already exists' % d
608
609         if (create_at!=False):
610             git.create_branch(self.get_name(), create_at)
611
612         os.makedirs(self.__patch_dir)
613
614         self.set_parent(parent_remote, parent_branch)
615
616         self.create_empty_field('applied')
617         self.create_empty_field('unapplied')
618         self._set_field('orig-base', git.get_head())
619
620         config.set(self.format_version_key(), str(FORMAT_VERSION))
621
622     def rename(self, to_name):
623         """Renames a series
624         """
625         to_stack = Series(to_name)
626
627         if to_stack.is_initialised():
628             raise StackException, '"%s" already exists' % to_stack.get_name()
629
630         patches = self.get_applied() + self.get_unapplied()
631
632         git.rename_branch(self.get_name(), to_name)
633
634         for patch in patches:
635             git.rename_ref('refs/patches/%s/%s' % (self.get_name(), patch),
636                            'refs/patches/%s/%s' % (to_name, patch))
637             git.rename_ref('refs/patches/%s/%s.log' % (self.get_name(), patch),
638                            'refs/patches/%s/%s.log' % (to_name, patch))
639         if os.path.isdir(self._dir()):
640             rename(os.path.join(self._basedir(), 'patches'),
641                    self.get_name(), to_stack.get_name())
642
643         # Rename the config section
644         for k in ['branch.%s', 'branch.%s.stgit']:
645             config.rename_section(k % self.get_name(), k % to_name)
646
647         self.__init__(to_name)
648
649     def clone(self, target_series):
650         """Clones a series
651         """
652         try:
653             # allow cloning of branches not under StGIT control
654             base = self.get_base()
655         except:
656             base = git.get_head()
657         Series(target_series).init(create_at = base)
658         new_series = Series(target_series)
659
660         # generate an artificial description file
661         new_series.set_description('clone of "%s"' % self.get_name())
662
663         # clone self's entire series as unapplied patches
664         try:
665             # allow cloning of branches not under StGIT control
666             applied = self.get_applied()
667             unapplied = self.get_unapplied()
668             patches = applied + unapplied
669             patches.reverse()
670         except:
671             patches = applied = unapplied = []
672         for p in patches:
673             patch = self.get_patch(p)
674             newpatch = new_series.new_patch(p, message = patch.get_description(),
675                                             can_edit = False, unapplied = True,
676                                             bottom = patch.get_bottom(),
677                                             top = patch.get_top(),
678                                             author_name = patch.get_authname(),
679                                             author_email = patch.get_authemail(),
680                                             author_date = patch.get_authdate())
681             if patch.get_log():
682                 out.info('Setting log to %s' %  patch.get_log())
683                 newpatch.set_log(patch.get_log())
684             else:
685                 out.info('No log for %s' % p)
686
687         # fast forward the cloned series to self's top
688         new_series.forward_patches(applied)
689
690         # Clone parent informations
691         value = config.get('branch.%s.remote' % self.get_name())
692         if value:
693             config.set('branch.%s.remote' % target_series, value)
694
695         value = config.get('branch.%s.merge' % self.get_name())
696         if value:
697             config.set('branch.%s.merge' % target_series, value)
698
699         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
700         if value:
701             config.set('branch.%s.stgit.parentbranch' % target_series, value)
702
703     def delete(self, force = False):
704         """Deletes an stgit series
705         """
706         if self.is_initialised():
707             patches = self.get_unapplied() + self.get_applied()
708             if not force and patches:
709                 raise StackException, \
710                       'Cannot delete: the series still contains patches'
711             for p in patches:
712                 self.get_patch(p).delete()
713
714             # remove the trash directory if any
715             if os.path.exists(self.__trash_dir):
716                 for fname in os.listdir(self.__trash_dir):
717                     os.remove(os.path.join(self.__trash_dir, fname))
718                 os.rmdir(self.__trash_dir)
719
720             # FIXME: find a way to get rid of those manual removals
721             # (move functionality to StgitObject ?)
722             if os.path.exists(self.__applied_file):
723                 os.remove(self.__applied_file)
724             if os.path.exists(self.__unapplied_file):
725                 os.remove(self.__unapplied_file)
726             if os.path.exists(self.__hidden_file):
727                 os.remove(self.__hidden_file)
728             if os.path.exists(self._dir()+'/orig-base'):
729                 os.remove(self._dir()+'/orig-base')
730
731             if not os.listdir(self.__patch_dir):
732                 os.rmdir(self.__patch_dir)
733             else:
734                 out.warn('Patch directory %s is not empty' % self.__patch_dir)
735
736             try:
737                 os.removedirs(self._dir())
738             except OSError:
739                 raise StackException('Series directory %s is not empty'
740                                      % self._dir())
741
742             try:
743                 git.delete_branch(self.get_name())
744             except GitException:
745                 out.warn('Could not delete branch "%s"' % self.get_name())
746
747         config.remove_section('branch.%s' % self.get_name())
748         config.remove_section('branch.%s.stgit' % self.get_name())
749
750     def refresh_patch(self, files = None, message = None, edit = False,
751                       show_patch = False,
752                       cache_update = True,
753                       author_name = None, author_email = None,
754                       author_date = None,
755                       committer_name = None, committer_email = None,
756                       backup = False, sign_str = None, log = 'refresh',
757                       notes = None, bottom = None):
758         """Generates a new commit for the topmost patch
759         """
760         patch = self.get_current_patch()
761         if not patch:
762             raise StackException, 'No patches applied'
763
764         descr = patch.get_description()
765         if not (message or descr):
766             edit = True
767             descr = ''
768         elif message:
769             descr = message
770
771         # TODO: move this out of the stgit.stack module, it is really
772         # for higher level commands to handle the user interaction
773         if not message and edit:
774             descr = edit_file(self, descr.rstrip(), \
775                               'Please edit the description for patch "%s" ' \
776                               'above.' % patch.get_name(), show_patch)
777
778         if not author_name:
779             author_name = patch.get_authname()
780         if not author_email:
781             author_email = patch.get_authemail()
782         if not author_date:
783             author_date = patch.get_authdate()
784         if not committer_name:
785             committer_name = patch.get_commname()
786         if not committer_email:
787             committer_email = patch.get_commemail()
788
789         descr = add_sign_line(descr, sign_str, committer_name, committer_email)
790
791         if not bottom:
792             bottom = patch.get_bottom()
793
794         commit_id = git.commit(files = files,
795                                message = descr, parents = [bottom],
796                                cache_update = cache_update,
797                                allowempty = True,
798                                author_name = author_name,
799                                author_email = author_email,
800                                author_date = author_date,
801                                committer_name = committer_name,
802                                committer_email = committer_email)
803
804         patch.set_bottom(bottom, backup = backup)
805         patch.set_top(commit_id, backup = backup)
806         patch.set_description(descr)
807         patch.set_authname(author_name)
808         patch.set_authemail(author_email)
809         patch.set_authdate(author_date)
810         patch.set_commname(committer_name)
811         patch.set_commemail(committer_email)
812
813         if log:
814             self.log_patch(patch, log, notes)
815
816         return commit_id
817
818     def undo_refresh(self):
819         """Undo the patch boundaries changes caused by 'refresh'
820         """
821         name = self.get_current()
822         assert(name)
823
824         patch = self.get_patch(name)
825         old_bottom = patch.get_old_bottom()
826         old_top = patch.get_old_top()
827
828         # the bottom of the patch is not changed by refresh. If the
829         # old_bottom is different, there wasn't any previous 'refresh'
830         # command (probably only a 'push')
831         if old_bottom != patch.get_bottom() or old_top == patch.get_top():
832             raise StackException, 'No undo information available'
833
834         git.reset(tree_id = old_top, check_out = False)
835         if patch.restore_old_boundaries():
836             self.log_patch(patch, 'undo')
837
838     def new_patch(self, name, message = None, can_edit = True,
839                   unapplied = False, show_patch = False,
840                   top = None, bottom = None, commit = True,
841                   author_name = None, author_email = None, author_date = None,
842                   committer_name = None, committer_email = None,
843                   before_existing = False):
844         """Creates a new patch, either pointing to an existing commit object,
845         or by creating a new commit object.
846         """
847
848         assert commit or (top and bottom)
849         assert not before_existing or (top and bottom)
850         assert not (commit and before_existing)
851         assert (top and bottom) or (not top and not bottom)
852         assert not top or (bottom == git.get_commit(top).get_parent())
853
854         if name != None:
855             self.__patch_name_valid(name)
856             if self.patch_exists(name):
857                 raise StackException, 'Patch "%s" already exists' % name
858
859         # TODO: move this out of the stgit.stack module, it is really
860         # for higher level commands to handle the user interaction
861         if not message and can_edit:
862             descr = edit_file(
863                 self, None,
864                 'Please enter the description for the patch above.',
865                 show_patch)
866         else:
867             descr = message
868
869         head = git.get_head()
870
871         if name == None:
872             name = make_patch_name(descr, self.patch_exists)
873
874         patch = self.get_patch(name)
875         patch.create()
876
877         patch.set_description(descr)
878         patch.set_authname(author_name)
879         patch.set_authemail(author_email)
880         patch.set_authdate(author_date)
881         patch.set_commname(committer_name)
882         patch.set_commemail(committer_email)
883
884         if before_existing:
885             insert_string(self.__applied_file, patch.get_name())
886         elif unapplied:
887             patches = [patch.get_name()] + self.get_unapplied()
888             write_strings(self.__unapplied_file, patches)
889             set_head = False
890         else:
891             append_string(self.__applied_file, patch.get_name())
892             set_head = True
893
894         if commit:
895             if top:
896                 top_commit = git.get_commit(top)
897             else:
898                 bottom = head
899                 top_commit = git.get_commit(head)
900
901             # create a commit for the patch (may be empty if top == bottom);
902             # only commit on top of the current branch
903             assert(unapplied or bottom == head)
904             commit_id = git.commit(message = descr, parents = [bottom],
905                                    cache_update = False,
906                                    tree_id = top_commit.get_tree(),
907                                    allowempty = True, set_head = set_head,
908                                    author_name = author_name,
909                                    author_email = author_email,
910                                    author_date = author_date,
911                                    committer_name = committer_name,
912                                    committer_email = committer_email)
913             # set the patch top to the new commit
914             patch.set_bottom(bottom)
915             patch.set_top(commit_id)
916         else:
917             assert top != bottom
918             patch.set_bottom(bottom)
919             patch.set_top(top)
920
921         self.log_patch(patch, 'new')
922
923         return patch
924
925     def delete_patch(self, name):
926         """Deletes a patch
927         """
928         self.__patch_name_valid(name)
929         patch = self.get_patch(name)
930
931         if self.__patch_is_current(patch):
932             self.pop_patch(name)
933         elif self.patch_applied(name):
934             raise StackException, 'Cannot remove an applied patch, "%s", ' \
935                   'which is not current' % name
936         elif not name in self.get_unapplied():
937             raise StackException, 'Unknown patch "%s"' % name
938
939         # save the commit id to a trash file
940         write_string(os.path.join(self.__trash_dir, name), patch.get_top())
941
942         patch.delete()
943
944         unapplied = self.get_unapplied()
945         unapplied.remove(name)
946         write_strings(self.__unapplied_file, unapplied)
947
948     def forward_patches(self, names):
949         """Try to fast-forward an array of patches.
950
951         On return, patches in names[0:returned_value] have been pushed on the
952         stack. Apply the rest with push_patch
953         """
954         unapplied = self.get_unapplied()
955
956         forwarded = 0
957         top = git.get_head()
958
959         for name in names:
960             assert(name in unapplied)
961
962             patch = self.get_patch(name)
963
964             head = top
965             bottom = patch.get_bottom()
966             top = patch.get_top()
967
968             # top != bottom always since we have a commit for each patch
969             if head == bottom:
970                 # reset the backup information. No logging since the
971                 # patch hasn't changed
972                 patch.set_bottom(head, backup = True)
973                 patch.set_top(top, backup = True)
974
975             else:
976                 head_tree = git.get_commit(head).get_tree()
977                 bottom_tree = git.get_commit(bottom).get_tree()
978                 if head_tree == bottom_tree:
979                     # We must just reparent this patch and create a new commit
980                     # for it
981                     descr = patch.get_description()
982                     author_name = patch.get_authname()
983                     author_email = patch.get_authemail()
984                     author_date = patch.get_authdate()
985                     committer_name = patch.get_commname()
986                     committer_email = patch.get_commemail()
987
988                     top_tree = git.get_commit(top).get_tree()
989
990                     top = git.commit(message = descr, parents = [head],
991                                      cache_update = False,
992                                      tree_id = top_tree,
993                                      allowempty = True,
994                                      author_name = author_name,
995                                      author_email = author_email,
996                                      author_date = author_date,
997                                      committer_name = committer_name,
998                                      committer_email = committer_email)
999
1000                     patch.set_bottom(head, backup = True)
1001                     patch.set_top(top, backup = True)
1002
1003                     self.log_patch(patch, 'push(f)')
1004                 else:
1005                     top = head
1006                     # stop the fast-forwarding, must do a real merge
1007                     break
1008
1009             forwarded+=1
1010             unapplied.remove(name)
1011
1012         if forwarded == 0:
1013             return 0
1014
1015         git.switch(top)
1016
1017         append_strings(self.__applied_file, names[0:forwarded])
1018         write_strings(self.__unapplied_file, unapplied)
1019
1020         return forwarded
1021
1022     def merged_patches(self, names):
1023         """Test which patches were merged upstream by reverse-applying
1024         them in reverse order. The function returns the list of
1025         patches detected to have been applied. The state of the tree
1026         is restored to the original one
1027         """
1028         patches = [self.get_patch(name) for name in names]
1029         patches.reverse()
1030
1031         merged = []
1032         for p in patches:
1033             if git.apply_diff(p.get_top(), p.get_bottom()):
1034                 merged.append(p.get_name())
1035         merged.reverse()
1036
1037         git.reset()
1038
1039         return merged
1040
1041     def push_empty_patch(self, name):
1042         """Pushes an empty patch on the stack
1043         """
1044         unapplied = self.get_unapplied()
1045         assert(name in unapplied)
1046
1047         # patch = self.get_patch(name)
1048         head = git.get_head()
1049
1050         append_string(self.__applied_file, name)
1051
1052         unapplied.remove(name)
1053         write_strings(self.__unapplied_file, unapplied)
1054
1055         self.refresh_patch(bottom = head, cache_update = False, log = 'push(m)')
1056
1057     def push_patch(self, name):
1058         """Pushes a patch on the stack
1059         """
1060         unapplied = self.get_unapplied()
1061         assert(name in unapplied)
1062
1063         patch = self.get_patch(name)
1064
1065         head = git.get_head()
1066         bottom = patch.get_bottom()
1067         top = patch.get_top()
1068         # top != bottom always since we have a commit for each patch
1069
1070         if head == bottom:
1071             # A fast-forward push. Just reset the backup
1072             # information. No need for logging
1073             patch.set_bottom(bottom, backup = True)
1074             patch.set_top(top, backup = True)
1075
1076             git.switch(top)
1077             append_string(self.__applied_file, name)
1078
1079             unapplied.remove(name)
1080             write_strings(self.__unapplied_file, unapplied)
1081             return False
1082
1083         # Need to create a new commit an merge in the old patch
1084         ex = None
1085         modified = False
1086
1087         # Try the fast applying first. If this fails, fall back to the
1088         # three-way merge
1089         if not git.apply_diff(bottom, top):
1090             # if git.apply_diff() fails, the patch requires a diff3
1091             # merge and can be reported as modified
1092             modified = True
1093
1094             # merge can fail but the patch needs to be pushed
1095             try:
1096                 git.merge(bottom, head, top, recursive = True)
1097             except git.GitException, ex:
1098                 out.error('The merge failed during "push".',
1099                           'Use "refresh" after fixing the conflicts or'
1100                           ' revert the operation with "push --undo".')
1101
1102         append_string(self.__applied_file, name)
1103
1104         unapplied.remove(name)
1105         write_strings(self.__unapplied_file, unapplied)
1106
1107         if not ex:
1108             # if the merge was OK and no conflicts, just refresh the patch
1109             # The GIT cache was already updated by the merge operation
1110             if modified:
1111                 log = 'push(m)'
1112             else:
1113                 log = 'push'
1114             self.refresh_patch(bottom = head, cache_update = False, log = log)
1115         else:
1116             # we store the correctly merged files only for
1117             # tracking the conflict history. Note that the
1118             # git.merge() operations should always leave the index
1119             # in a valid state (i.e. only stage 0 files)
1120             self.refresh_patch(bottom = head, cache_update = False,
1121                                log = 'push(c)')
1122             raise StackException, str(ex)
1123
1124         return modified
1125
1126     def undo_push(self):
1127         name = self.get_current()
1128         assert(name)
1129
1130         patch = self.get_patch(name)
1131         old_bottom = patch.get_old_bottom()
1132         old_top = patch.get_old_top()
1133
1134         # the top of the patch is changed by a push operation only
1135         # together with the bottom (otherwise the top was probably
1136         # modified by 'refresh'). If they are both unchanged, there
1137         # was a fast forward
1138         if old_bottom == patch.get_bottom() and old_top != patch.get_top():
1139             raise StackException, 'No undo information available'
1140
1141         git.reset()
1142         self.pop_patch(name)
1143         ret = patch.restore_old_boundaries()
1144         if ret:
1145             self.log_patch(patch, 'undo')
1146
1147         return ret
1148
1149     def pop_patch(self, name, keep = False):
1150         """Pops the top patch from the stack
1151         """
1152         applied = self.get_applied()
1153         applied.reverse()
1154         assert(name in applied)
1155
1156         patch = self.get_patch(name)
1157
1158         if git.get_head_file() == self.get_name():
1159             if keep and not git.apply_diff(git.get_head(), patch.get_bottom()):
1160                 raise StackException(
1161                     'Failed to pop patches while preserving the local changes')
1162             git.switch(patch.get_bottom(), keep)
1163         else:
1164             git.set_branch(self.get_name(), patch.get_bottom())
1165
1166         # save the new applied list
1167         idx = applied.index(name) + 1
1168
1169         popped = applied[:idx]
1170         popped.reverse()
1171         unapplied = popped + self.get_unapplied()
1172         write_strings(self.__unapplied_file, unapplied)
1173
1174         del applied[:idx]
1175         applied.reverse()
1176         write_strings(self.__applied_file, applied)
1177
1178     def empty_patch(self, name):
1179         """Returns True if the patch is empty
1180         """
1181         self.__patch_name_valid(name)
1182         patch = self.get_patch(name)
1183         bottom = patch.get_bottom()
1184         top = patch.get_top()
1185
1186         if bottom == top:
1187             return True
1188         elif git.get_commit(top).get_tree() \
1189                  == git.get_commit(bottom).get_tree():
1190             return True
1191
1192         return False
1193
1194     def rename_patch(self, oldname, newname):
1195         self.__patch_name_valid(newname)
1196
1197         applied = self.get_applied()
1198         unapplied = self.get_unapplied()
1199
1200         if oldname == newname:
1201             raise StackException, '"To" name and "from" name are the same'
1202
1203         if newname in applied or newname in unapplied:
1204             raise StackException, 'Patch "%s" already exists' % newname
1205
1206         if oldname in unapplied:
1207             self.get_patch(oldname).rename(newname)
1208             unapplied[unapplied.index(oldname)] = newname
1209             write_strings(self.__unapplied_file, unapplied)
1210         elif oldname in applied:
1211             self.get_patch(oldname).rename(newname)
1212
1213             applied[applied.index(oldname)] = newname
1214             write_strings(self.__applied_file, applied)
1215         else:
1216             raise StackException, 'Unknown patch "%s"' % oldname
1217
1218     def log_patch(self, patch, message, notes = None):
1219         """Generate a log commit for a patch
1220         """
1221         top = git.get_commit(patch.get_top())
1222         old_log = patch.get_log()
1223
1224         if message is None:
1225             # replace the current log entry
1226             if not old_log:
1227                 raise StackException, \
1228                       'No log entry to annotate for patch "%s"' \
1229                       % patch.get_name()
1230             replace = True
1231             log_commit = git.get_commit(old_log)
1232             msg = log_commit.get_log().split('\n')[0]
1233             log_parent = log_commit.get_parent()
1234             if log_parent:
1235                 parents = [log_parent]
1236             else:
1237                 parents = []
1238         else:
1239             # generate a new log entry
1240             replace = False
1241             msg = '%s\t%s' % (message, top.get_id_hash())
1242             if old_log:
1243                 parents = [old_log]
1244             else:
1245                 parents = []
1246
1247         if notes:
1248             msg += '\n\n' + notes
1249
1250         log = git.commit(message = msg, parents = parents,
1251                          cache_update = False, tree_id = top.get_tree(),
1252                          allowempty = True)
1253         patch.set_log(log)
1254
1255     def hide_patch(self, name):
1256         """Add the patch to the hidden list.
1257         """
1258         unapplied = self.get_unapplied()
1259         if name not in unapplied:
1260             # keep the checking order for backward compatibility with
1261             # the old hidden patches functionality
1262             if self.patch_applied(name):
1263                 raise StackException, 'Cannot hide applied patch "%s"' % name
1264             elif self.patch_hidden(name):
1265                 raise StackException, 'Patch "%s" already hidden' % name
1266             else:
1267                 raise StackException, 'Unknown patch "%s"' % name
1268
1269         if not self.patch_hidden(name):
1270             # check needed for backward compatibility with the old
1271             # hidden patches functionality
1272             append_string(self.__hidden_file, name)
1273
1274         unapplied.remove(name)
1275         write_strings(self.__unapplied_file, unapplied)
1276
1277     def unhide_patch(self, name):
1278         """Remove the patch from the hidden list.
1279         """
1280         hidden = self.get_hidden()
1281         if not name in hidden:
1282             if self.patch_applied(name) or self.patch_unapplied(name):
1283                 raise StackException, 'Patch "%s" not hidden' % name
1284             else:
1285                 raise StackException, 'Unknown patch "%s"' % name
1286
1287         hidden.remove(name)
1288         write_strings(self.__hidden_file, hidden)
1289
1290         if not self.patch_applied(name) and not self.patch_unapplied(name):
1291             # check needed for backward compatibility with the old
1292             # hidden patches functionality
1293             append_string(self.__unapplied_file, name)