chiark / gitweb /
c7569b21ee4cd2029cad8b1541f3cd550c9e80de
[stgit] / stgit / stack.py
1 """Basic quilt-like functionality
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, re
22
23 from stgit.utils import *
24 from stgit import git, basedir, templates
25 from stgit.config import config
26 from shutil import copyfile
27
28
class StackException(Exception):
    """Error raised by the stack code when an operation cannot proceed
    """
32
class FilterUntil:
    """Stateful line predicate: accept lines until a stop condition is met

    Every call evaluates until_test(x); the first time it is true the
    filter switches off and rejects that line and all later ones.
    While still active, a line is accepted only if it does not start
    with the given prefix.
    """
    def __init__(self):
        # True while the stop condition has not been seen yet
        self.should_print = True

    def __call__(self, x, until_test, prefix):
        stop_here = until_test(x)
        if stop_here:
            self.should_print = False
        return self.should_print and not x.startswith(prefix)
42
43 #
44 # Functions
45 #
# Prefix of the helper lines written into the message file; lines
# starting with it are filtered out again by __clean_comments().
__comment_prefix = 'STG:'
# Marker line written by edit_file() just before the patch diff; this
# line and everything after it are dropped by __clean_comments().
__patch_prefix = 'STG_PATCH:'
48
def __clean_comments(f):
    """Strip the status lines from an open commit message file, in place

    Drops every line starting with the comment prefix, drops the patch
    marker line together with everything after it, removes trailing
    empty lines and rewrites the file with what is left.  The file
    must be open for both reading and writing.
    """
    f.seek(0)

    # the marker is compared against whole lines, newline included
    marker_line = __patch_prefix + '\n'
    is_marker = lambda t: t == marker_line

    keep = FilterUntil()
    kept = [ln for ln in f.readlines()
            if keep(ln, is_marker, __comment_prefix)]

    # remove empty lines at the end
    while kept and kept[-1] == '\n':
        kept.pop()

    f.seek(0)
    f.truncate()
    f.writelines(kept)
67
def edit_file(series, line, comment, show_patch = True):
    """Let the user edit a message in an editor and return the result

    A scratch file '.stgitmsg.txt' is created in the current directory
    and filled with, in this order: 'line' if given, otherwise the
    'patchdescr.tmpl' template if there is one, otherwise an empty
    line; then the instruction 'comment' and usage hints, each marked
    with the comment prefix; then, if 'show_patch' is true, the patch
    marker followed by a diff starting at the bottom of the current
    patch of 'series'.  After call_editor() returns, the marked lines,
    the diff and trailing empty lines are stripped, the scratch file
    is deleted and the remaining text is returned.

    If an exception interrupts the process, the scratch file is left
    on disk but its handle is always closed.
    """
    fname = '.stgitmsg.txt'
    tmpl = templates.get_template('patchdescr.tmpl')

    f = file(fname, 'w+')
    try:
        if line:
            print >> f, line
        elif tmpl:
            print >> f, tmpl,
        else:
            print >> f
        print >> f, __comment_prefix, comment
        print >> f, __comment_prefix, \
              'Lines prefixed with "%s" will be automatically removed.' \
              % __comment_prefix
        print >> f, __comment_prefix, \
              'Trailing empty lines will be automatically removed.'

        if show_patch:
            print >> f, __patch_prefix
            # NOTE(review): the second revision is None -- presumably
            # this diffs the patch bottom against the working tree;
            # confirm against git.diff()
            git.diff([], series.get_patch(series.get_current()).get_bottom(), None, f)

        #Vim modeline must be near the end.
        print >> f, __comment_prefix, 'vi: set textwidth=75 filetype=diff nobackup:'
    finally:
        # never leak the handle, even if writing or git.diff() failed
        f.close()

    call_editor(fname)

    f = file(fname, 'r+')
    try:
        __clean_comments(f)
        f.seek(0)
        result = f.read()
    finally:
        f.close()

    os.remove(fname)

    return result
107
108 #
109 # Classes
110 #
111
class StgitObject:
    """An object with stgit-like properties stored as files in a directory

    Each property ("field") lives in a file named after it inside the
    directory registered with _set_dir().
    """
    def _set_dir(self, dir):
        self.__dir = dir

    def _dir(self):
        return self.__dir

    def create_empty_field(self, name):
        """Create the file backing field 'name' with no content
        """
        create_empty_file(os.path.join(self.__dir, name))

    def _get_field(self, name, multiline = False):
        """Return the stored value of field 'name'

        None is returned when the backing file does not exist or when
        the stored value is the empty string.
        """
        path = os.path.join(self.__dir, name)
        if not os.path.isfile(path):
            return None

        content = read_string(path, multiline)
        if content == '':
            return None
        return content

    def _set_field(self, name, value, multiline = False):
        """Store 'value' in field 'name'

        An empty or None value removes the backing file, if present,
        instead of writing it.
        """
        path = os.path.join(self.__dir, name)
        if value:
            write_string(path, value, multiline)
        elif os.path.isfile(path):
            os.remove(path)
140
141     
class Patch(StgitObject):
    """Basic patch implementation

    The patch fields are stored as files in <series_dir>/<name>.  The
    top commit id is also written to the ref file <refs_dir>/<name>
    and the log id to <refs_dir>/<name>.log.
    """
    def __init__(self, name, series_dir, refs_dir):
        self.__series_dir = series_dir
        self.__refs_dir = refs_dir
        self.__bind_name(name)

    def __bind_name(self, name):
        """Record the patch name and derive every path depending on it
        """
        self.__name = name
        self._set_dir(os.path.join(self.__series_dir, name))
        self.__top_ref_file = os.path.join(self.__refs_dir, name)
        self.__log_ref_file = os.path.join(self.__refs_dir, name + '.log')

    def create(self):
        """Create the patch directory with empty 'bottom' and 'top' fields
        """
        os.mkdir(self._dir())
        for field in ('bottom', 'top'):
            self.create_empty_field(field)

    def delete(self):
        """Remove the patch directory, its top ref and, if any, its log ref
        """
        patch_dir = self._dir()
        for entry in os.listdir(patch_dir):
            os.remove(os.path.join(patch_dir, entry))
        os.rmdir(patch_dir)
        os.remove(self.__top_ref_file)
        if os.path.exists(self.__log_ref_file):
            os.remove(self.__log_ref_file)

    def get_name(self):
        return self.__name

    def rename(self, newname):
        """Rename the patch, moving its directory and ref files on disk
        """
        previous = (self._dir(), self.__top_ref_file, self.__log_ref_file)
        self.__bind_name(newname)
        old_dir, old_top_ref, old_log_ref = previous

        os.rename(old_dir, self._dir())
        os.rename(old_top_ref, self.__top_ref_file)
        # the log ref is optional
        if os.path.exists(old_log_ref):
            os.rename(old_log_ref, self.__log_ref_file)

    def __update_top_ref(self, ref):
        write_string(self.__top_ref_file, ref)

    def __update_log_ref(self, ref):
        write_string(self.__log_ref_file, ref)

    def update_top_ref(self):
        """Rewrite the top ref file from the 'top' field, if it is set
        """
        top = self.get_top()
        if top:
            self.__update_top_ref(top)

    def get_old_bottom(self):
        return self._get_field('bottom.old')

    def get_bottom(self):
        return self._get_field('bottom')

    def set_bottom(self, value, backup = False):
        """Set the bottom id; with 'backup' the previous value is saved
        in 'bottom.old' first
        """
        if backup:
            self._set_field('bottom.old', self._get_field('bottom'))
        self._set_field('bottom', value)

    def get_old_top(self):
        return self._get_field('top.old')

    def get_top(self):
        return self._get_field('top')

    def set_top(self, value, backup = False):
        """Set the top id and the top ref; with 'backup' the previous
        value is saved in 'top.old' first
        """
        if backup:
            self._set_field('top.old', self._get_field('top'))
        self._set_field('top', value)
        self.__update_top_ref(value)

    def restore_old_boundaries(self):
        """Restore 'bottom' and 'top' from their backups

        Returns True if both backups were present and have been
        restored, False (changing nothing) otherwise.
        """
        bottom = self._get_field('bottom.old')
        top = self._get_field('top.old')

        if not (top and bottom):
            return False

        self._set_field('bottom', bottom)
        self._set_field('top', top)
        self.__update_top_ref(top)
        return True

    def get_description(self):
        return self._get_field('description', True)

    def set_description(self, line):
        self._set_field('description', line, True)

    # The author/committer setters below fall back to the values
    # reported by git.author() / git.committer() when given an empty
    # value.

    def get_authname(self):
        return self._get_field('authname')

    def set_authname(self, name):
        self._set_field('authname', name or git.author().name)

    def get_authemail(self):
        return self._get_field('authemail')

    def set_authemail(self, email):
        self._set_field('authemail', email or git.author().email)

    def get_authdate(self):
        return self._get_field('authdate')

    def set_authdate(self, date):
        self._set_field('authdate', date or git.author().date)

    def get_commname(self):
        return self._get_field('commname')

    def set_commname(self, name):
        self._set_field('commname', name or git.committer().name)

    def get_commemail(self):
        return self._get_field('commemail')

    def set_commemail(self, email):
        self._set_field('commemail', email or git.committer().email)

    def get_log(self):
        return self._get_field('log')

    def set_log(self, value, backup = False):
        """Set the log id and the log ref ('backup' is accepted but unused)
        """
        self._set_field('log', value)
        self.__update_log_ref(value)
275
# The current StGIT metadata format version.  It is the value recorded
# in the branch configuration when a series is initialised, and the
# version older branches are upgraded to.
FORMAT_VERSION = 2
278
279 class PatchSet(StgitObject):
280     def __init__(self, name = None):
281         try:
282             if name:
283                 self.set_name (name)
284             else:
285                 self.set_name (git.get_head_file())
286             self.__base_dir = basedir.get()
287         except git.GitException, ex:
288             raise StackException, 'GIT tree not initialised: %s' % ex
289
290         self._set_dir(os.path.join(self.__base_dir, 'patches', self.get_name()))
291
292     def get_name(self):
293         return self.__name
294     def set_name(self, name):
295         self.__name = name
296
297     def _basedir(self):
298         return self.__base_dir
299
300     def get_head(self):
301         """Return the head of the branch
302         """
303         crt = self.get_current_patch()
304         if crt:
305             return crt.get_top()
306         else:
307             return self.get_base()
308
309     def get_protected(self):
310         return os.path.isfile(os.path.join(self._dir(), 'protected'))
311
312     def protect(self):
313         protect_file = os.path.join(self._dir(), 'protected')
314         if not os.path.isfile(protect_file):
315             create_empty_file(protect_file)
316
317     def unprotect(self):
318         protect_file = os.path.join(self._dir(), 'protected')
319         if os.path.isfile(protect_file):
320             os.remove(protect_file)
321
322     def __branch_descr(self):
323         return 'branch.%s.description' % self.get_name()
324
325     def get_description(self):
326         return config.get(self.__branch_descr()) or ''
327
328     def set_description(self, line):
329         if line:
330             config.set(self.__branch_descr(), line)
331         else:
332             config.unset(self.__branch_descr())
333
334     def head_top_equal(self):
335         """Return true if the head and the top are the same
336         """
337         crt = self.get_current_patch()
338         if not crt:
339             # we don't care, no patches applied
340             return True
341         return git.get_head() == crt.get_top()
342
343     def is_initialised(self):
344         """Checks if series is already initialised
345         """
346         return bool(config.get(self.format_version_key()))
347
348
349 class Series(PatchSet):
350     """Class including the operations on series
351     """
352     def __init__(self, name = None):
353         """Takes a series name as the parameter.
354         """
355         PatchSet.__init__(self, name)
356
357         # Update the branch to the latest format version if it is
358         # initialized, but don't touch it if it isn't.
359         self.update_to_current_format_version()
360
361         self.__refs_dir = os.path.join(self._basedir(), 'refs', 'patches',
362                                        self.get_name())
363
364         self.__applied_file = os.path.join(self._dir(), 'applied')
365         self.__unapplied_file = os.path.join(self._dir(), 'unapplied')
366         self.__hidden_file = os.path.join(self._dir(), 'hidden')
367
368         # where this series keeps its patches
369         self.__patch_dir = os.path.join(self._dir(), 'patches')
370
371         # trash directory
372         self.__trash_dir = os.path.join(self._dir(), 'trash')
373
374     def format_version_key(self):
375         return 'branch.%s.stgit.stackformatversion' % self.get_name()
376
    def update_to_current_format_version(self):
        """Update a potentially older StGIT directory structure to the
        latest version. Note: This function should depend as little as
        possible on external functions that may change during a format
        version bump, since it must remain able to process older formats.

        The upgrade is done one step at a time (0 -> 1, then 1 -> 2),
        re-reading the version before each step.  A branch without any
        StGIT metadata is left untouched.  Raises StackException if
        the branch ends up at a version other than FORMAT_VERSION."""

        branch_dir = os.path.join(self._basedir(), 'patches', self.get_name())
        def get_format_version():
            """Return the integer format version number, or None if the
            branch doesn't have any StGIT metadata at all, of any version."""
            fv = config.get(self.format_version_key())
            ofv = config.get('branch.%s.stgitformatversion' % self.get_name())
            if fv:
                # Great, there's an explicitly recorded format version
                # number, which means that the branch is initialized and
                # of that exact version.
                return int(fv)
            elif ofv:
                # Old name for the version info, upgrade it
                config.set(self.format_version_key(), ofv)
                config.unset('branch.%s.stgitformatversion' % self.get_name())
                return int(ofv)
            elif os.path.isdir(os.path.join(branch_dir, 'patches')):
                # There's a .git/patches/<branch>/patches directory, which
                # means this is an initialized version 1 branch.
                return 1
            elif os.path.isdir(branch_dir):
                # There's a .git/patches/<branch> directory, which means
                # this is an initialized version 0 branch.
                return 0
            else:
                # The branch doesn't seem to be initialized at all.
                return None
        # Record version v in the configuration and report the upgrade.
        def set_format_version(v):
            out.info('Upgraded branch %s to format version %d' % (self.get_name(), v))
            config.set(self.format_version_key(), '%d' % v)
        # Create directory d (and missing parents) unless it exists.
        def mkdir(d):
            if not os.path.isdir(d):
                os.makedirs(d)
        # Remove file f if it exists.
        def rm(f):
            if os.path.exists(f):
                os.remove(f)

        # Update 0 -> 1.
        # Move every patch directory from the branch directory into the
        # new 'patches' subdirectory and write its top ref file.
        if get_format_version() == 0:
            mkdir(os.path.join(branch_dir, 'trash'))
            patch_dir = os.path.join(branch_dir, 'patches')
            mkdir(patch_dir)
            refs_dir = os.path.join(self._basedir(), 'refs', 'patches', self.get_name())
            mkdir(refs_dir)
            for patch in (file(os.path.join(branch_dir, 'unapplied')).readlines()
                          + file(os.path.join(branch_dir, 'applied')).readlines()):
                patch = patch.strip()
                os.rename(os.path.join(branch_dir, patch),
                          os.path.join(patch_dir, patch))
                Patch(patch, patch_dir, refs_dir).update_top_ref()
            set_format_version(1)

        # Update 1 -> 2.
        # Move the branch description into the configuration and drop
        # the 'description', 'current' and refs/bases files.
        if get_format_version() == 1:
            desc_file = os.path.join(branch_dir, 'description')
            if os.path.isfile(desc_file):
                desc = read_string(desc_file)
                if desc:
                    config.set('branch.%s.description' % self.get_name(), desc)
                rm(desc_file)
            rm(os.path.join(branch_dir, 'current'))
            rm(os.path.join(self._basedir(), 'refs', 'bases', self.get_name()))
            set_format_version(2)

        # Make sure we're at the latest version.
        if not get_format_version() in [None, FORMAT_VERSION]:
            raise StackException('Branch %s is at format version %d, expected %d'
                                 % (self.get_name(), get_format_version(), FORMAT_VERSION))
451
452     def __patch_name_valid(self, name):
453         """Raise an exception if the patch name is not valid.
454         """
455         if not name or re.search('[^\w.-]', name):
456             raise StackException, 'Invalid patch name: "%s"' % name
457
458     def get_patch(self, name):
459         """Return a Patch object for the given name
460         """
461         return Patch(name, self.__patch_dir, self.__refs_dir)
462
463     def get_current_patch(self):
464         """Return a Patch object representing the topmost patch, or
465         None if there is no such patch."""
466         crt = self.get_current()
467         if not crt:
468             return None
469         return self.get_patch(crt)
470
471     def get_current(self):
472         """Return the name of the topmost patch, or None if there is
473         no such patch."""
474         try:
475             applied = self.get_applied()
476         except StackException:
477             # No "applied" file: branch is not initialized.
478             return None
479         try:
480             return applied[-1]
481         except IndexError:
482             # No patches applied.
483             return None
484
485     def get_applied(self):
486         if not os.path.isfile(self.__applied_file):
487             raise StackException, 'Branch "%s" not initialised' % self.get_name()
488         return read_strings(self.__applied_file)
489
490     def get_unapplied(self):
491         if not os.path.isfile(self.__unapplied_file):
492             raise StackException, 'Branch "%s" not initialised' % self.get_name()
493         return read_strings(self.__unapplied_file)
494
495     def get_hidden(self):
496         if not os.path.isfile(self.__hidden_file):
497             return []
498         return read_strings(self.__hidden_file)
499
500     def get_base(self):
501         # Return the parent of the bottommost patch, if there is one.
502         if os.path.isfile(self.__applied_file):
503             bottommost = file(self.__applied_file).readline().strip()
504             if bottommost:
505                 return self.get_patch(bottommost).get_bottom()
506         # No bottommost patch, so just return HEAD
507         return git.get_head()
508
509     def get_parent_remote(self):
510         value = config.get('branch.%s.remote' % self.get_name())
511         if value:
512             return value
513         elif 'origin' in git.remotes_list():
514             out.note(('No parent remote declared for stack "%s",'
515                       ' defaulting to "origin".' % self.get_name()),
516                      ('Consider setting "branch.%s.remote" and'
517                       ' "branch.%s.merge" with "git repo-config".'
518                       % (self.get_name(), self.get_name())))
519             return 'origin'
520         else:
521             raise StackException, 'Cannot find a parent remote for "%s"' % self.get_name()
522
523     def __set_parent_remote(self, remote):
524         value = config.set('branch.%s.remote' % self.get_name(), remote)
525
526     def get_parent_branch(self):
527         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
528         if value:
529             return value
530         elif git.rev_parse('heads/origin'):
531             out.note(('No parent branch declared for stack "%s",'
532                       ' defaulting to "heads/origin".' % self.get_name()),
533                      ('Consider setting "branch.%s.stgit.parentbranch"'
534                       ' with "git repo-config".' % self.get_name()))
535             return 'heads/origin'
536         else:
537             raise StackException, 'Cannot find a parent branch for "%s"' % self.get_name()
538
539     def __set_parent_branch(self, name):
540         if config.get('branch.%s.remote' % self.get_name()):
541             # Never set merge if remote is not set to avoid
542             # possibly-erroneous lookups into 'origin'
543             config.set('branch.%s.merge' % self.get_name(), name)
544         config.set('branch.%s.stgit.parentbranch' % self.get_name(), name)
545
546     def set_parent(self, remote, localbranch):
547         if localbranch:
548             self.__set_parent_remote(remote)
549             self.__set_parent_branch(localbranch)
550         # We'll enforce this later
551 #         else:
552 #             raise StackException, 'Parent branch (%s) should be specified for %s' % localbranch, self.get_name()
553
554     def __patch_is_current(self, patch):
555         return patch.get_name() == self.get_current()
556
557     def patch_applied(self, name):
558         """Return true if the patch exists in the applied list
559         """
560         return name in self.get_applied()
561
562     def patch_unapplied(self, name):
563         """Return true if the patch exists in the unapplied list
564         """
565         return name in self.get_unapplied()
566
567     def patch_hidden(self, name):
568         """Return true if the patch is hidden.
569         """
570         return name in self.get_hidden()
571
572     def patch_exists(self, name):
573         """Return true if there is a patch with the given name, false
574         otherwise."""
575         return self.patch_applied(name) or self.patch_unapplied(name) \
576                or self.patch_hidden(name)
577
578     def init(self, create_at=False, parent_remote=None, parent_branch=None):
579         """Initialises the stgit series
580         """
581         if self.is_initialised():
582             raise StackException, '%s already initialized' % self.get_name()
583         for d in [self._dir(), self.__refs_dir]:
584             if os.path.exists(d):
585                 raise StackException, '%s already exists' % d
586
587         if (create_at!=False):
588             git.create_branch(self.get_name(), create_at)
589
590         os.makedirs(self.__patch_dir)
591
592         self.set_parent(parent_remote, parent_branch)
593
594         self.create_empty_field('applied')
595         self.create_empty_field('unapplied')
596         os.makedirs(self.__refs_dir)
597         self._set_field('orig-base', git.get_head())
598
599         config.set(self.format_version_key(), str(FORMAT_VERSION))
600
601     def rename(self, to_name):
602         """Renames a series
603         """
604         to_stack = Series(to_name)
605
606         if to_stack.is_initialised():
607             raise StackException, '"%s" already exists' % to_stack.get_name()
608
609         git.rename_branch(self.get_name(), to_name)
610
611         if os.path.isdir(self._dir()):
612             rename(os.path.join(self._basedir(), 'patches'),
613                    self.get_name(), to_stack.get_name())
614         if os.path.exists(self.__refs_dir):
615             rename(os.path.join(self._basedir(), 'refs', 'patches'),
616                    self.get_name(), to_stack.get_name())
617
618         # Rename the config section
619         for k in ['branch.%s', 'branch.%s.stgit']:
620             config.rename_section(k % self.get_name(), k % to_name)
621
622         self.__init__(to_name)
623
624     def clone(self, target_series):
625         """Clones a series
626         """
627         try:
628             # allow cloning of branches not under StGIT control
629             base = self.get_base()
630         except:
631             base = git.get_head()
632         Series(target_series).init(create_at = base)
633         new_series = Series(target_series)
634
635         # generate an artificial description file
636         new_series.set_description('clone of "%s"' % self.get_name())
637
638         # clone self's entire series as unapplied patches
639         try:
640             # allow cloning of branches not under StGIT control
641             applied = self.get_applied()
642             unapplied = self.get_unapplied()
643             patches = applied + unapplied
644             patches.reverse()
645         except:
646             patches = applied = unapplied = []
647         for p in patches:
648             patch = self.get_patch(p)
649             newpatch = new_series.new_patch(p, message = patch.get_description(),
650                                             can_edit = False, unapplied = True,
651                                             bottom = patch.get_bottom(),
652                                             top = patch.get_top(),
653                                             author_name = patch.get_authname(),
654                                             author_email = patch.get_authemail(),
655                                             author_date = patch.get_authdate())
656             if patch.get_log():
657                 out.info('Setting log to %s' %  patch.get_log())
658                 newpatch.set_log(patch.get_log())
659             else:
660                 out.info('No log for %s' % p)
661
662         # fast forward the cloned series to self's top
663         new_series.forward_patches(applied)
664
665         # Clone parent informations
666         value = config.get('branch.%s.remote' % self.get_name())
667         if value:
668             config.set('branch.%s.remote' % target_series, value)
669
670         value = config.get('branch.%s.merge' % self.get_name())
671         if value:
672             config.set('branch.%s.merge' % target_series, value)
673
674         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
675         if value:
676             config.set('branch.%s.stgit.parentbranch' % target_series, value)
677
678     def delete(self, force = False):
679         """Deletes an stgit series
680         """
681         if self.is_initialised():
682             patches = self.get_unapplied() + self.get_applied()
683             if not force and patches:
684                 raise StackException, \
685                       'Cannot delete: the series still contains patches'
686             for p in patches:
687                 self.get_patch(p).delete()
688
689             # remove the trash directory if any
690             if os.path.exists(self.__trash_dir):
691                 for fname in os.listdir(self.__trash_dir):
692                     os.remove(os.path.join(self.__trash_dir, fname))
693                 os.rmdir(self.__trash_dir)
694
695             # FIXME: find a way to get rid of those manual removals
696             # (move functionality to StgitObject ?)
697             if os.path.exists(self.__applied_file):
698                 os.remove(self.__applied_file)
699             if os.path.exists(self.__unapplied_file):
700                 os.remove(self.__unapplied_file)
701             if os.path.exists(self.__hidden_file):
702                 os.remove(self.__hidden_file)
703             if os.path.exists(self._dir()+'/orig-base'):
704                 os.remove(self._dir()+'/orig-base')
705
706             if not os.listdir(self.__patch_dir):
707                 os.rmdir(self.__patch_dir)
708             else:
709                 out.warn('Patch directory %s is not empty' % self.__patch_dir)
710
711             try:
712                 os.removedirs(self._dir())
713             except OSError:
714                 raise StackException('Series directory %s is not empty'
715                                      % self._dir())
716
717             try:
718                 os.removedirs(self.__refs_dir)
719             except OSError:
720                 out.warn('Refs directory %s is not empty' % self.__refs_dir)
721
722         # Cleanup parent informations
723         # FIXME: should one day make use of git-config --section-remove,
724         # scheduled for 1.5.1
725         config.unset('branch.%s.remote' % self.get_name())
726         config.unset('branch.%s.merge' % self.get_name())
727         config.unset('branch.%s.stgit.parentbranch' % self.get_name())
728         config.unset(self.format_version_key())
729
730     def refresh_patch(self, files = None, message = None, edit = False,
731                       show_patch = False,
732                       cache_update = True,
733                       author_name = None, author_email = None,
734                       author_date = None,
735                       committer_name = None, committer_email = None,
736                       backup = False, sign_str = None, log = 'refresh',
737                       notes = None):
738         """Generates a new commit for the given patch
739         """
740         name = self.get_current()
741         if not name:
742             raise StackException, 'No patches applied'
743
744         patch = self.get_patch(name)
745
746         descr = patch.get_description()
747         if not (message or descr):
748             edit = True
749             descr = ''
750         elif message:
751             descr = message
752
753         if not message and edit:
754             descr = edit_file(self, descr.rstrip(), \
755                               'Please edit the description for patch "%s" ' \
756                               'above.' % name, show_patch)
757
758         if not author_name:
759             author_name = patch.get_authname()
760         if not author_email:
761             author_email = patch.get_authemail()
762         if not author_date:
763             author_date = patch.get_authdate()
764         if not committer_name:
765             committer_name = patch.get_commname()
766         if not committer_email:
767             committer_email = patch.get_commemail()
768
769         if sign_str:
770             descr = descr.rstrip()
771             if descr.find("\nSigned-off-by:") < 0 \
772                and descr.find("\nAcked-by:") < 0:
773                 descr = descr + "\n"
774
775             descr = '%s\n%s: %s <%s>\n' % (descr, sign_str,
776                                            committer_name, committer_email)
777
778         bottom = patch.get_bottom()
779
780         commit_id = git.commit(files = files,
781                                message = descr, parents = [bottom],
782                                cache_update = cache_update,
783                                allowempty = True,
784                                author_name = author_name,
785                                author_email = author_email,
786                                author_date = author_date,
787                                committer_name = committer_name,
788                                committer_email = committer_email)
789
790         patch.set_bottom(bottom, backup = backup)
791         patch.set_top(commit_id, backup = backup)
792         patch.set_description(descr)
793         patch.set_authname(author_name)
794         patch.set_authemail(author_email)
795         patch.set_authdate(author_date)
796         patch.set_commname(committer_name)
797         patch.set_commemail(committer_email)
798
799         if log:
800             self.log_patch(patch, log, notes)
801
802         return commit_id
803
804     def undo_refresh(self):
805         """Undo the patch boundaries changes caused by 'refresh'
806         """
807         name = self.get_current()
808         assert(name)
809
810         patch = self.get_patch(name)
811         old_bottom = patch.get_old_bottom()
812         old_top = patch.get_old_top()
813
814         # the bottom of the patch is not changed by refresh. If the
815         # old_bottom is different, there wasn't any previous 'refresh'
816         # command (probably only a 'push')
817         if old_bottom != patch.get_bottom() or old_top == patch.get_top():
818             raise StackException, 'No undo information available'
819
820         git.reset(tree_id = old_top, check_out = False)
821         if patch.restore_old_boundaries():
822             self.log_patch(patch, 'undo')
823
824     def new_patch(self, name, message = None, can_edit = True,
825                   unapplied = False, show_patch = False,
826                   top = None, bottom = None, commit = True,
827                   author_name = None, author_email = None, author_date = None,
828                   committer_name = None, committer_email = None,
829                   before_existing = False):
830         """Creates a new patch
831         """
832
833         if name != None:
834             self.__patch_name_valid(name)
835             if self.patch_exists(name):
836                 raise StackException, 'Patch "%s" already exists' % name
837
838         if not message and can_edit:
839             descr = edit_file(
840                 self, None,
841                 'Please enter the description for the patch above.',
842                 show_patch)
843         else:
844             descr = message
845
846         head = git.get_head()
847
848         if name == None:
849             name = make_patch_name(descr, self.patch_exists)
850
851         patch = self.get_patch(name)
852         patch.create()
853
854         if not bottom:
855             bottom = head
856         if not top:
857             top = head
858
859         patch.set_bottom(bottom)
860         patch.set_top(top)
861         patch.set_description(descr)
862         patch.set_authname(author_name)
863         patch.set_authemail(author_email)
864         patch.set_authdate(author_date)
865         patch.set_commname(committer_name)
866         patch.set_commemail(committer_email)
867
868         if before_existing:
869             insert_string(self.__applied_file, patch.get_name())
870             # no need to commit anything as the object is already
871             # present (mainly used by 'uncommit')
872             commit = False
873         elif unapplied:
874             patches = [patch.get_name()] + self.get_unapplied()
875             write_strings(self.__unapplied_file, patches)
876             set_head = False
877         else:
878             append_string(self.__applied_file, patch.get_name())
879             set_head = True
880
881         if commit:
882             # create a commit for the patch (may be empty if top == bottom);
883             # only commit on top of the current branch
884             assert(unapplied or bottom == head)
885             top_commit = git.get_commit(top)
886             commit_id = git.commit(message = descr, parents = [bottom],
887                                    cache_update = False,
888                                    tree_id = top_commit.get_tree(),
889                                    allowempty = True, set_head = set_head,
890                                    author_name = author_name,
891                                    author_email = author_email,
892                                    author_date = author_date,
893                                    committer_name = committer_name,
894                                    committer_email = committer_email)
895             # set the patch top to the new commit
896             patch.set_top(commit_id)
897
898         self.log_patch(patch, 'new')
899
900         return patch
901
902     def delete_patch(self, name):
903         """Deletes a patch
904         """
905         self.__patch_name_valid(name)
906         patch = self.get_patch(name)
907
908         if self.__patch_is_current(patch):
909             self.pop_patch(name)
910         elif self.patch_applied(name):
911             raise StackException, 'Cannot remove an applied patch, "%s", ' \
912                   'which is not current' % name
913         elif not name in self.get_unapplied():
914             raise StackException, 'Unknown patch "%s"' % name
915
916         # save the commit id to a trash file
917         write_string(os.path.join(self.__trash_dir, name), patch.get_top())
918
919         patch.delete()
920
921         unapplied = self.get_unapplied()
922         unapplied.remove(name)
923         write_strings(self.__unapplied_file, unapplied)
924
925     def forward_patches(self, names):
926         """Try to fast-forward an array of patches.
927
928         On return, patches in names[0:returned_value] have been pushed on the
929         stack. Apply the rest with push_patch
930         """
931         unapplied = self.get_unapplied()
932
933         forwarded = 0
934         top = git.get_head()
935
936         for name in names:
937             assert(name in unapplied)
938
939             patch = self.get_patch(name)
940
941             head = top
942             bottom = patch.get_bottom()
943             top = patch.get_top()
944
945             # top != bottom always since we have a commit for each patch
946             if head == bottom:
947                 # reset the backup information. No logging since the
948                 # patch hasn't changed
949                 patch.set_bottom(head, backup = True)
950                 patch.set_top(top, backup = True)
951
952             else:
953                 head_tree = git.get_commit(head).get_tree()
954                 bottom_tree = git.get_commit(bottom).get_tree()
955                 if head_tree == bottom_tree:
956                     # We must just reparent this patch and create a new commit
957                     # for it
958                     descr = patch.get_description()
959                     author_name = patch.get_authname()
960                     author_email = patch.get_authemail()
961                     author_date = patch.get_authdate()
962                     committer_name = patch.get_commname()
963                     committer_email = patch.get_commemail()
964
965                     top_tree = git.get_commit(top).get_tree()
966
967                     top = git.commit(message = descr, parents = [head],
968                                      cache_update = False,
969                                      tree_id = top_tree,
970                                      allowempty = True,
971                                      author_name = author_name,
972                                      author_email = author_email,
973                                      author_date = author_date,
974                                      committer_name = committer_name,
975                                      committer_email = committer_email)
976
977                     patch.set_bottom(head, backup = True)
978                     patch.set_top(top, backup = True)
979
980                     self.log_patch(patch, 'push(f)')
981                 else:
982                     top = head
983                     # stop the fast-forwarding, must do a real merge
984                     break
985
986             forwarded+=1
987             unapplied.remove(name)
988
989         if forwarded == 0:
990             return 0
991
992         git.switch(top)
993
994         append_strings(self.__applied_file, names[0:forwarded])
995         write_strings(self.__unapplied_file, unapplied)
996
997         return forwarded
998
999     def merged_patches(self, names):
1000         """Test which patches were merged upstream by reverse-applying
1001         them in reverse order. The function returns the list of
1002         patches detected to have been applied. The state of the tree
1003         is restored to the original one
1004         """
1005         patches = [self.get_patch(name) for name in names]
1006         patches.reverse()
1007
1008         merged = []
1009         for p in patches:
1010             if git.apply_diff(p.get_top(), p.get_bottom()):
1011                 merged.append(p.get_name())
1012         merged.reverse()
1013
1014         git.reset()
1015
1016         return merged
1017
    def push_patch(self, name, empty = False):
        """Pushes a patch on the stack

        name -- the patch to push; it must be in the unapplied list
        empty -- when set, the patch content is not applied at all and
                 the patch is pushed with top = bottom = HEAD

        The patch is moved from the unapplied to the applied file even
        if the three-way merge fails. In that case the patch is
        refreshed with the 'push(c)' log entry and a StackException
        carrying the text of the git error is raised.

        Returns True if 'empty' was requested or if git.apply_diff()
        failed and the three-way merge had to be used, False otherwise.
        """
        unapplied = self.get_unapplied()
        assert(name in unapplied)

        patch = self.get_patch(name)

        # boundaries are read before any of them gets modified below
        head = git.get_head()
        bottom = patch.get_bottom()
        top = patch.get_top()

        # 'ex' keeps the merge exception, if any, so that the stack
        # files can be updated before the error is reported
        ex = None
        modified = False

        # top != bottom always since we have a commit for each patch
        if empty:
            # just make an empty patch (top = bottom = HEAD). This
            # option is useful to allow undoing already merged
            # patches. The top is updated by refresh_patch since we
            # need an empty commit
            patch.set_bottom(head, backup = True)
            patch.set_top(head, backup = True)
            modified = True
        elif head == bottom:
            # reset the backup information. No need for logging
            patch.set_bottom(bottom, backup = True)
            patch.set_top(top, backup = True)

            # fast-forward: the patch already applies on top of HEAD
            git.switch(top)
        else:
            # new patch needs to be refreshed.
            # The current patch is empty after merge.
            patch.set_bottom(head, backup = True)
            patch.set_top(head, backup = True)

            # Try the fast applying first. If this fails, fall back to the
            # three-way merge
            if not git.apply_diff(bottom, top):
                # if git.apply_diff() fails, the patch requires a diff3
                # merge and can be reported as modified
                modified = True

                # merge can fail but the patch needs to be pushed
                try:
                    git.merge(bottom, head, top, recursive = True)
                except git.GitException, ex:
                    out.error('The merge failed during "push".',
                              'Use "refresh" after fixing the conflicts or'
                              ' revert the operation with "push --undo".')

        # the patch counts as applied from now on, conflicts or not
        append_string(self.__applied_file, name)

        unapplied.remove(name)
        write_strings(self.__unapplied_file, unapplied)

        # head == bottom case doesn't need to refresh the patch
        if empty or head != bottom:
            if not ex:
                # if the merge was OK and no conflicts, just refresh the patch
                # The GIT cache was already updated by the merge operation
                if modified:
                    log = 'push(m)'
                else:
                    log = 'push'
                self.refresh_patch(cache_update = False, log = log)
            else:
                # we store the correctly merged files only for
                # tracking the conflict history. Note that the
                # git.merge() operations should always leave the index
                # in a valid state (i.e. only stage 0 files)
                self.refresh_patch(cache_update = False, log = 'push(c)')
                raise StackException, str(ex)

        return modified
1093
1094     def undo_push(self):
1095         name = self.get_current()
1096         assert(name)
1097
1098         patch = self.get_patch(name)
1099         old_bottom = patch.get_old_bottom()
1100         old_top = patch.get_old_top()
1101
1102         # the top of the patch is changed by a push operation only
1103         # together with the bottom (otherwise the top was probably
1104         # modified by 'refresh'). If they are both unchanged, there
1105         # was a fast forward
1106         if old_bottom == patch.get_bottom() and old_top != patch.get_top():
1107             raise StackException, 'No undo information available'
1108
1109         git.reset()
1110         self.pop_patch(name)
1111         ret = patch.restore_old_boundaries()
1112         if ret:
1113             self.log_patch(patch, 'undo')
1114
1115         return ret
1116
1117     def pop_patch(self, name, keep = False):
1118         """Pops the top patch from the stack
1119         """
1120         applied = self.get_applied()
1121         applied.reverse()
1122         assert(name in applied)
1123
1124         patch = self.get_patch(name)
1125
1126         if git.get_head_file() == self.get_name():
1127             if keep and not git.apply_diff(git.get_head(), patch.get_bottom()):
1128                 raise StackException(
1129                     'Failed to pop patches while preserving the local changes')
1130             git.switch(patch.get_bottom(), keep)
1131         else:
1132             git.set_branch(self.get_name(), patch.get_bottom())
1133
1134         # save the new applied list
1135         idx = applied.index(name) + 1
1136
1137         popped = applied[:idx]
1138         popped.reverse()
1139         unapplied = popped + self.get_unapplied()
1140         write_strings(self.__unapplied_file, unapplied)
1141
1142         del applied[:idx]
1143         applied.reverse()
1144         write_strings(self.__applied_file, applied)
1145
1146     def empty_patch(self, name):
1147         """Returns True if the patch is empty
1148         """
1149         self.__patch_name_valid(name)
1150         patch = self.get_patch(name)
1151         bottom = patch.get_bottom()
1152         top = patch.get_top()
1153
1154         if bottom == top:
1155             return True
1156         elif git.get_commit(top).get_tree() \
1157                  == git.get_commit(bottom).get_tree():
1158             return True
1159
1160         return False
1161
1162     def rename_patch(self, oldname, newname):
1163         self.__patch_name_valid(newname)
1164
1165         applied = self.get_applied()
1166         unapplied = self.get_unapplied()
1167
1168         if oldname == newname:
1169             raise StackException, '"To" name and "from" name are the same'
1170
1171         if newname in applied or newname in unapplied:
1172             raise StackException, 'Patch "%s" already exists' % newname
1173
1174         if oldname in unapplied:
1175             self.get_patch(oldname).rename(newname)
1176             unapplied[unapplied.index(oldname)] = newname
1177             write_strings(self.__unapplied_file, unapplied)
1178         elif oldname in applied:
1179             self.get_patch(oldname).rename(newname)
1180
1181             applied[applied.index(oldname)] = newname
1182             write_strings(self.__applied_file, applied)
1183         else:
1184             raise StackException, 'Unknown patch "%s"' % oldname
1185
1186     def log_patch(self, patch, message, notes = None):
1187         """Generate a log commit for a patch
1188         """
1189         top = git.get_commit(patch.get_top())
1190         old_log = patch.get_log()
1191
1192         if message is None:
1193             # replace the current log entry
1194             if not old_log:
1195                 raise StackException, \
1196                       'No log entry to annotate for patch "%s"' \
1197                       % patch.get_name()
1198             replace = True
1199             log_commit = git.get_commit(old_log)
1200             msg = log_commit.get_log().split('\n')[0]
1201             log_parent = log_commit.get_parent()
1202             if log_parent:
1203                 parents = [log_parent]
1204             else:
1205                 parents = []
1206         else:
1207             # generate a new log entry
1208             replace = False
1209             msg = '%s\t%s' % (message, top.get_id_hash())
1210             if old_log:
1211                 parents = [old_log]
1212             else:
1213                 parents = []
1214
1215         if notes:
1216             msg += '\n\n' + notes
1217
1218         log = git.commit(message = msg, parents = parents,
1219                          cache_update = False, tree_id = top.get_tree(),
1220                          allowempty = True)
1221         patch.set_log(log)
1222
1223     def hide_patch(self, name):
1224         """Add the patch to the hidden list.
1225         """
1226         unapplied = self.get_unapplied()
1227         if name not in unapplied:
1228             # keep the checking order for backward compatibility with
1229             # the old hidden patches functionality
1230             if self.patch_applied(name):
1231                 raise StackException, 'Cannot hide applied patch "%s"' % name
1232             elif self.patch_hidden(name):
1233                 raise StackException, 'Patch "%s" already hidden' % name
1234             else:
1235                 raise StackException, 'Unknown patch "%s"' % name
1236
1237         if not self.patch_hidden(name):
1238             # check needed for backward compatibility with the old
1239             # hidden patches functionality
1240             append_string(self.__hidden_file, name)
1241
1242         unapplied.remove(name)
1243         write_strings(self.__unapplied_file, unapplied)
1244
1245     def unhide_patch(self, name):
1246         """Remove the patch from the hidden list.
1247         """
1248         hidden = self.get_hidden()
1249         if not name in hidden:
1250             if self.patch_applied(name) or self.patch_unapplied(name):
1251                 raise StackException, 'Patch "%s" not hidden' % name
1252             else:
1253                 raise StackException, 'Unknown patch "%s"' % name
1254
1255         hidden.remove(name)
1256         write_strings(self.__hidden_file, hidden)
1257
1258         if not self.patch_applied(name) and not self.patch_unapplied(name):
1259             # check needed for backward compatibility with the old
1260             # hidden patches functionality
1261             append_string(self.__unapplied_file, name)