chiark / gitweb /
5e9d4fb41a2a78ab04b98797e16a4faaf8498789
[stgit] / stgit / stack.py
1 """Basic quilt-like functionality
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, re
22 from email.Utils import formatdate
23
24 from stgit.utils import *
25 from stgit.out import *
26 from stgit import git, basedir, templates
27 from stgit.config import config
28 from shutil import copyfile
29
30
31 # stack exception class
class StackException(Exception):
    """Error raised by StGIT stack (series and patch) operations."""
34
class FilterUntil:
    """Stateful line predicate used to strip status lines from a message.

    Each call returns True when the line ``x`` should be kept.  Lines that
    start with ``prefix`` are rejected, and as soon as ``until_test``
    accepts a line, that line and every later one are rejected too.
    """
    def __init__(self):
        # Becomes False for good once until_test() has matched a line.
        self.should_print = True

    def __call__(self, x, until_test, prefix):
        # until_test() is evaluated on every call, even after the cut-off.
        reached_end = until_test(x)
        if reached_end:
            self.should_print = False
        if not self.should_print:
            return False
        return x[:len(prefix)] != prefix
44
45 #
46 # Functions
47 #
# Lines of a message file starting with this prefix are editor-only status
# lines; __clean_comments() drops them.
__comment_prefix = 'STG:'
# A line consisting of exactly this marker (plus newline) introduces the
# diff appended by edit_file(); __clean_comments() drops it and all
# following lines.
__patch_prefix = 'STG_PATCH:'
50
def __clean_comments(f):
    """Strip StGIT status lines from an open message file, in place.

    Every line starting with __comment_prefix is dropped, as is the first
    line that is exactly __patch_prefix followed by a newline together
    with everything after it.  Bare trailing '\\n' lines are removed, then
    the file is truncated and rewritten with the remaining lines.
    """
    f.seek(0)
    end_marker = __patch_prefix + '\n'

    kept = []
    for text in f.readlines():
        if text == end_marker:
            # the appended patch diff starts here; ignore the rest
            break
        if text[:len(__comment_prefix)] != __comment_prefix:
            kept.append(text)

    while kept and kept[-1] == '\n':
        kept.pop()

    f.seek(0)
    f.truncate()
    f.writelines(kept)
69
70 # TODO: move this out of the stgit.stack module, it is really for
71 # higher level commands to handle the user interaction
def edit_file(series, line, comment, show_patch = True):
    """Let the user edit a message in an external editor and return it.

    A temporary '.stgitmsg.txt' is written in the current directory with
    the initial text (``line`` if given, else the 'patchdescr.tmpl'
    template if there is one, else an empty line), followed by
    __comment_prefix status lines including ``comment``.  When
    ``show_patch`` is true, a __patch_prefix marker and the git.diff()
    output against the current patch's bottom are appended.  After
    call_editor() returns, __clean_comments() strips the status lines and
    the diff, and the remaining text is returned.  The temporary file is
    removed on the normal (non-exception) path.

    NOTE(review): with show_patch set this looks up series.get_current();
    presumably it fails when no patch is applied — confirm with callers.
    """
    fname = '.stgitmsg.txt'
    tmpl = templates.get_template('patchdescr.tmpl')

    f = file(fname, 'w+')
    if line:
        print >> f, line
    elif tmpl:
        # trailing comma suppresses print's newline; presumably the
        # template text supplies its own line ending — verify
        print >> f, tmpl,
    else:
        print >> f
    print >> f, __comment_prefix, comment
    print >> f, __comment_prefix, \
          'Lines prefixed with "%s" will be automatically removed.' \
          % __comment_prefix
    print >> f, __comment_prefix, \
          'Trailing empty lines will be automatically removed.'

    if show_patch:
       # everything from this marker on is discarded by __clean_comments()
       print >> f, __patch_prefix
       # series.get_patch(series.get_current()).get_top()
       diff_str = git.diff(rev1 = series.get_patch(series.get_current()).get_bottom())
       f.write(diff_str)

    #Vim modeline must be near the end.
    print >> f, __comment_prefix, 'vi: set textwidth=75 filetype=diff nobackup:'
    f.close()

    call_editor(fname)

    f = file(fname, 'r+')

    __clean_comments(f)
    f.seek(0)
    result = f.read()

    f.close()
    os.remove(fname)

    return result
112
113 #
114 # Classes
115 #
116
class StgitObject:
    """Base for objects whose properties are stored as files in a directory.

    Each "field" is a file named after the field inside the object's
    directory.  A missing file, or one whose content reads as '', means
    the field is unset.
    """
    def _set_dir(self, dir):
        self.__dir = dir

    def _dir(self):
        return self.__dir

    def create_empty_field(self, name):
        create_empty_file(os.path.join(self.__dir, name))

    def _get_field(self, name, multiline = False):
        """Return the content of field ``name``, or None if it is unset."""
        field_path = os.path.join(self.__dir, name)
        if not os.path.isfile(field_path):
            return None
        content = read_string(field_path, multiline)
        if content == '':
            return None
        return content

    def _set_field(self, name, value, multiline = False):
        """Store ``value`` in field ``name``; a false value deletes the field."""
        field_path = os.path.join(self.__dir, name)
        if value:
            write_string(field_path, value, multiline)
        elif os.path.isfile(field_path):
            os.remove(field_path)
145
146     
class Patch(StgitObject):
    """Basic patch implementation.

    Metadata (bottom/top commit ids, description, author and committer
    details, log id) is stored as fields in <series_dir>/<name>.  The top
    commit and the log id are also recorded as the git refs
    <refs_base>/<name> and <refs_base>/<name>.log.
    """
    def __init_refs(self):
        self.__top_ref = self.__refs_base + '/' + self.__name
        self.__log_ref = self.__top_ref + '.log'

    def __init__(self, name, series_dir, refs_base):
        self.__series_dir = series_dir
        self.__name = name
        self._set_dir(os.path.join(self.__series_dir, self.__name))
        self.__refs_base = refs_base
        self.__init_refs()

    def create(self):
        """Create the patch directory with empty bottom/top fields."""
        os.mkdir(self._dir())
        self.create_empty_field('bottom')
        self.create_empty_field('top')

    def delete(self):
        """Remove the patch directory, its top ref and (if any) its log ref."""
        for f in os.listdir(self._dir()):
            os.remove(os.path.join(self._dir(), f))
        os.rmdir(self._dir())
        git.delete_ref(self.__top_ref)
        if git.ref_exists(self.__log_ref):
            git.delete_ref(self.__log_ref)

    def get_name(self):
        return self.__name

    def rename(self, newname):
        """Rename the patch directory and its refs to ``newname``."""
        olddir = self._dir()
        old_top_ref = self.__top_ref
        old_log_ref = self.__log_ref
        self.__name = newname
        self._set_dir(os.path.join(self.__series_dir, self.__name))
        self.__init_refs()

        git.rename_ref(old_top_ref, self.__top_ref)
        if git.ref_exists(old_log_ref):
            git.rename_ref(old_log_ref, self.__log_ref)
        os.rename(olddir, self._dir())

    def __update_top_ref(self, ref):
        git.set_ref(self.__top_ref, ref)

    def __update_log_ref(self, ref):
        git.set_ref(self.__log_ref, ref)

    def update_top_ref(self):
        """Point the top ref at the stored top commit, if one is set."""
        top = self.get_top()
        if top:
            self.__update_top_ref(top)

    def get_old_bottom(self):
        return self._get_field('bottom.old')

    def get_bottom(self):
        return self._get_field('bottom')

    def set_bottom(self, value, backup = False):
        """Set the bottom; with ``backup`` the previous one goes to bottom.old."""
        if backup:
            curr = self._get_field('bottom')
            self._set_field('bottom.old', curr)
        self._set_field('bottom', value)

    def get_old_top(self):
        return self._get_field('top.old')

    def get_top(self):
        return self._get_field('top')

    def set_top(self, value, backup = False):
        """Set the top and its ref; with ``backup`` the previous top goes to
        top.old."""
        if backup:
            curr = self._get_field('top')
            self._set_field('top.old', curr)
        self._set_field('top', value)
        self.__update_top_ref(value)

    def restore_old_boundaries(self):
        """Restore bottom/top from the .old backups.

        Returns True if both backups existed and were restored, else False.
        """
        bottom = self._get_field('bottom.old')
        top = self._get_field('top.old')

        if top and bottom:
            self._set_field('bottom', bottom)
            self._set_field('top', top)
            self.__update_top_ref(top)
            return True
        else:
            return False

    def get_description(self):
        return self._get_field('description', True)

    def set_description(self, line):
        self._set_field('description', line, True)

    def get_authname(self):
        return self._get_field('authname')

    def set_authname(self, name):
        self._set_field('authname', name or git.author().name)

    def get_authemail(self):
        return self._get_field('authemail')

    def set_authemail(self, email):
        self._set_field('authemail', email or git.author().email)

    def get_authdate(self):
        """Return the author date, or None if it is not set.

        A date stored in git's raw "<seconds> <+/-HHMM>" form is converted
        to an RFC 2822 date string expressed in the recorded time zone;
        any other stored form is returned unchanged.
        """
        date = self._get_field('authdate')
        if not date:
            return date

        m = re.match(r'([0-9]+)\s+([+-])([0-9]{2})([0-9]{2})', date)
        if m:
            # Unix time (seconds) + time zone
            secs, sign, hours, minutes = m.groups()
            # formatdate() renders the UTC wall-clock time; shift the
            # epoch seconds by the zone offset so that the rendered time
            # agrees with the zone label appended below.
            offset = int(hours) * 3600 + int(minutes) * 60
            if sign == '-':
                offset = -offset
            # [:-5] drops formatdate()'s '-0000' zone, keeping the space
            date = formatdate(int(secs) + offset)[:-5] + sign + hours + minutes

        return date

    def set_authdate(self, date):
        self._set_field('authdate', date or git.author().date)

    def get_commname(self):
        return self._get_field('commname')

    def set_commname(self, name):
        self._set_field('commname', name or git.committer().name)

    def get_commemail(self):
        return self._get_field('commemail')

    def set_commemail(self, email):
        self._set_field('commemail', email or git.committer().email)

    def get_log(self):
        return self._get_field('log')

    def set_log(self, value, backup = False):
        # ``backup`` is accepted for interface symmetry but unused
        self._set_field('log', value)
        self.__update_log_ref(value)
289
# The current StGIT metadata format version.  Series.init() records it in
# git config, and Series.update_to_current_format_version() upgrades older
# branches to it.
FORMAT_VERSION = 2
292
class PatchSet(StgitObject):
    """Branch-level state shared by StGIT stack implementations.

    Holds the branch name and the StGIT base directory and implements the
    branch properties (head, protection flag, description).  Subclasses
    are expected to provide get_current_patch(), get_base() and
    format_version_key().
    """
    def __init__(self, name = None):
        try:
            if not name:
                name = git.get_head_file()
            self.set_name(name)
            self.__base_dir = basedir.get()
        except git.GitException:
            raise StackException('GIT tree not initialised: %s'
                                 % sys.exc_info()[1])

        self._set_dir(os.path.join(self.__base_dir, 'patches', self.get_name()))

    def get_name(self):
        return self.__name

    def set_name(self, name):
        self.__name = name

    def _basedir(self):
        return self.__base_dir

    def get_head(self):
        """Return the head of the branch: the top of the current patch, or
        the stack base when no patch is applied."""
        crt = self.get_current_patch()
        if not crt:
            return self.get_base()
        return crt.get_top()

    def __protect_file(self):
        # marker file whose presence means the branch is protected
        return os.path.join(self._dir(), 'protected')

    def get_protected(self):
        return os.path.isfile(self.__protect_file())

    def protect(self):
        marker = self.__protect_file()
        if not os.path.isfile(marker):
            create_empty_file(marker)

    def unprotect(self):
        marker = self.__protect_file()
        if os.path.isfile(marker):
            os.remove(marker)

    def __branch_descr(self):
        return 'branch.%s.description' % self.get_name()

    def get_description(self):
        return config.get(self.__branch_descr()) or ''

    def set_description(self, line):
        key = self.__branch_descr()
        if not line:
            config.unset(key)
        else:
            config.set(key, line)

    def head_top_equal(self):
        """Return true if the head and the top are the same (always true
        when no patch is applied)."""
        crt = self.get_current_patch()
        return (not crt) or git.get_head() == crt.get_top()

    def is_initialised(self):
        """Checks if series is already initialised (has a recorded
        format version)."""
        version = config.get(self.format_version_key())
        return bool(version)
361
362
def shortlog(patches):
    """Return `git-shortlog` output summarising the commits of ``patches``.

    For each patch, in order, the commits between its bottom and top are
    listed with `git-log --pretty=short`; the concatenated log is then fed
    to `git-shortlog`.
    """
    chunks = []
    for p in patches:
        log_cmd = Run('git-log', '--pretty=short',
                      p.get_top(), '^%s' % p.get_bottom())
        chunks.append(log_cmd.raw_output())
    return Run('git-shortlog').raw_input(''.join(chunks)).raw_output()
368
369 class Series(PatchSet):
370     """Class including the operations on series
371     """
372     def __init__(self, name = None):
373         """Takes a series name as the parameter.
374         """
375         PatchSet.__init__(self, name)
376
377         # Update the branch to the latest format version if it is
378         # initialized, but don't touch it if it isn't.
379         self.update_to_current_format_version()
380
381         self.__refs_base = 'refs/patches/%s' % self.get_name()
382
383         self.__applied_file = os.path.join(self._dir(), 'applied')
384         self.__unapplied_file = os.path.join(self._dir(), 'unapplied')
385         self.__hidden_file = os.path.join(self._dir(), 'hidden')
386
387         # where this series keeps its patches
388         self.__patch_dir = os.path.join(self._dir(), 'patches')
389
390         # trash directory
391         self.__trash_dir = os.path.join(self._dir(), 'trash')
392
393     def format_version_key(self):
394         return 'branch.%s.stgit.stackformatversion' % self.get_name()
395
    def update_to_current_format_version(self):
        """Update a potentially older StGIT directory structure to the
        latest version. Note: This function should depend as little as
        possible on external functions that may change during a format
        version bump, since it must remain able to process older formats.

        Upgrades are applied one step at a time (0 -> 1, then 1 -> 2), each
        recording the new version in git config.  A branch with no StGIT
        metadata at all is left alone.  Raises StackException if the branch
        ends up at any version other than FORMAT_VERSION."""

        branch_dir = os.path.join(self._basedir(), 'patches', self.get_name())
        def get_format_version():
            """Return the integer format version number, or None if the
            branch doesn't have any StGIT metadata at all, of any version."""
            fv = config.get(self.format_version_key())
            ofv = config.get('branch.%s.stgitformatversion' % self.get_name())
            if fv:
                # Great, there's an explicitly recorded format version
                # number, which means that the branch is initialized and
                # of that exact version.
                return int(fv)
            elif ofv:
                # Old name for the version info, upgrade it
                config.set(self.format_version_key(), ofv)
                config.unset('branch.%s.stgitformatversion' % self.get_name())
                return int(ofv)
            elif os.path.isdir(os.path.join(branch_dir, 'patches')):
                # There's a .git/patches/<branch>/patches directory, which
                # means this is an initialized version 1 branch.
                return 1
            elif os.path.isdir(branch_dir):
                # There's a .git/patches/<branch> directory, which means
                # this is an initialized version 0 branch.
                return 0
            else:
                # The branch doesn't seem to be initialized at all.
                return None
        def set_format_version(v):
            # record (and report) the version just reached
            out.info('Upgraded branch %s to format version %d' % (self.get_name(), v))
            config.set(self.format_version_key(), '%d' % v)
        # idempotent helpers: do nothing if the target already (or no
        # longer) exists
        def mkdir(d):
            if not os.path.isdir(d):
                os.makedirs(d)
        def rm(f):
            if os.path.exists(f):
                os.remove(f)
        def rm_ref(ref):
            if git.ref_exists(ref):
                git.delete_ref(ref)

        # Update 0 -> 1.
        # Version 0 kept each patch directory directly in the branch
        # directory; move them under patches/ and create their top refs.
        if get_format_version() == 0:
            mkdir(os.path.join(branch_dir, 'trash'))
            patch_dir = os.path.join(branch_dir, 'patches')
            mkdir(patch_dir)
            refs_base = 'refs/patches/%s' % self.get_name()
            for patch in (file(os.path.join(branch_dir, 'unapplied')).readlines()
                          + file(os.path.join(branch_dir, 'applied')).readlines()):
                patch = patch.strip()
                os.rename(os.path.join(branch_dir, patch),
                          os.path.join(patch_dir, patch))
                Patch(patch, patch_dir, refs_base).update_top_ref()
            set_format_version(1)

        # Update 1 -> 2.
        # Move the branch description into git config and drop the
        # 'current' file and the refs/bases/<branch> ref.
        if get_format_version() == 1:
            desc_file = os.path.join(branch_dir, 'description')
            if os.path.isfile(desc_file):
                desc = read_string(desc_file)
                if desc:
                    config.set('branch.%s.description' % self.get_name(), desc)
                rm(desc_file)
            rm(os.path.join(branch_dir, 'current'))
            rm_ref('refs/bases/%s' % self.get_name())
            set_format_version(2)

        # Make sure we're at the latest version.
        if not get_format_version() in [None, FORMAT_VERSION]:
            raise StackException('Branch %s is at format version %d, expected %d'
                                 % (self.get_name(), get_format_version(), FORMAT_VERSION))
472
473     def __patch_name_valid(self, name):
474         """Raise an exception if the patch name is not valid.
475         """
476         if not name or re.search('[^\w.-]', name):
477             raise StackException, 'Invalid patch name: "%s"' % name
478
479     def get_patch(self, name):
480         """Return a Patch object for the given name
481         """
482         return Patch(name, self.__patch_dir, self.__refs_base)
483
484     def get_current_patch(self):
485         """Return a Patch object representing the topmost patch, or
486         None if there is no such patch."""
487         crt = self.get_current()
488         if not crt:
489             return None
490         return self.get_patch(crt)
491
492     def get_current(self):
493         """Return the name of the topmost patch, or None if there is
494         no such patch."""
495         try:
496             applied = self.get_applied()
497         except StackException:
498             # No "applied" file: branch is not initialized.
499             return None
500         try:
501             return applied[-1]
502         except IndexError:
503             # No patches applied.
504             return None
505
506     def get_applied(self):
507         if not os.path.isfile(self.__applied_file):
508             raise StackException, 'Branch "%s" not initialised' % self.get_name()
509         return read_strings(self.__applied_file)
510
511     def get_unapplied(self):
512         if not os.path.isfile(self.__unapplied_file):
513             raise StackException, 'Branch "%s" not initialised' % self.get_name()
514         return read_strings(self.__unapplied_file)
515
516     def get_hidden(self):
517         if not os.path.isfile(self.__hidden_file):
518             return []
519         return read_strings(self.__hidden_file)
520
521     def get_base(self):
522         # Return the parent of the bottommost patch, if there is one.
523         if os.path.isfile(self.__applied_file):
524             bottommost = file(self.__applied_file).readline().strip()
525             if bottommost:
526                 return self.get_patch(bottommost).get_bottom()
527         # No bottommost patch, so just return HEAD
528         return git.get_head()
529
530     def get_parent_remote(self):
531         value = config.get('branch.%s.remote' % self.get_name())
532         if value:
533             return value
534         elif 'origin' in git.remotes_list():
535             out.note(('No parent remote declared for stack "%s",'
536                       ' defaulting to "origin".' % self.get_name()),
537                      ('Consider setting "branch.%s.remote" and'
538                       ' "branch.%s.merge" with "git config".'
539                       % (self.get_name(), self.get_name())))
540             return 'origin'
541         else:
542             raise StackException, 'Cannot find a parent remote for "%s"' % self.get_name()
543
544     def __set_parent_remote(self, remote):
545         value = config.set('branch.%s.remote' % self.get_name(), remote)
546
547     def get_parent_branch(self):
548         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
549         if value:
550             return value
551         elif git.rev_parse('heads/origin'):
552             out.note(('No parent branch declared for stack "%s",'
553                       ' defaulting to "heads/origin".' % self.get_name()),
554                      ('Consider setting "branch.%s.stgit.parentbranch"'
555                       ' with "git config".' % self.get_name()))
556             return 'heads/origin'
557         else:
558             raise StackException, 'Cannot find a parent branch for "%s"' % self.get_name()
559
560     def __set_parent_branch(self, name):
561         if config.get('branch.%s.remote' % self.get_name()):
562             # Never set merge if remote is not set to avoid
563             # possibly-erroneous lookups into 'origin'
564             config.set('branch.%s.merge' % self.get_name(), name)
565         config.set('branch.%s.stgit.parentbranch' % self.get_name(), name)
566
567     def set_parent(self, remote, localbranch):
568         if localbranch:
569             if remote:
570                 self.__set_parent_remote(remote)
571             self.__set_parent_branch(localbranch)
572         # We'll enforce this later
573 #         else:
574 #             raise StackException, 'Parent branch (%s) should be specified for %s' % localbranch, self.get_name()
575
576     def __patch_is_current(self, patch):
577         return patch.get_name() == self.get_current()
578
579     def patch_applied(self, name):
580         """Return true if the patch exists in the applied list
581         """
582         return name in self.get_applied()
583
584     def patch_unapplied(self, name):
585         """Return true if the patch exists in the unapplied list
586         """
587         return name in self.get_unapplied()
588
589     def patch_hidden(self, name):
590         """Return true if the patch is hidden.
591         """
592         return name in self.get_hidden()
593
594     def patch_exists(self, name):
595         """Return true if there is a patch with the given name, false
596         otherwise."""
597         return self.patch_applied(name) or self.patch_unapplied(name) \
598                or self.patch_hidden(name)
599
600     def init(self, create_at=False, parent_remote=None, parent_branch=None):
601         """Initialises the stgit series
602         """
603         if self.is_initialised():
604             raise StackException, '%s already initialized' % self.get_name()
605         for d in [self._dir()]:
606             if os.path.exists(d):
607                 raise StackException, '%s already exists' % d
608
609         if (create_at!=False):
610             git.create_branch(self.get_name(), create_at)
611
612         os.makedirs(self.__patch_dir)
613
614         self.set_parent(parent_remote, parent_branch)
615
616         self.create_empty_field('applied')
617         self.create_empty_field('unapplied')
618         self._set_field('orig-base', git.get_head())
619
620         config.set(self.format_version_key(), str(FORMAT_VERSION))
621
622     def rename(self, to_name):
623         """Renames a series
624         """
625         to_stack = Series(to_name)
626
627         if to_stack.is_initialised():
628             raise StackException, '"%s" already exists' % to_stack.get_name()
629
630         patches = self.get_applied() + self.get_unapplied()
631
632         git.rename_branch(self.get_name(), to_name)
633
634         for patch in patches:
635             git.rename_ref('refs/patches/%s/%s' % (self.get_name(), patch),
636                            'refs/patches/%s/%s' % (to_name, patch))
637             git.rename_ref('refs/patches/%s/%s.log' % (self.get_name(), patch),
638                            'refs/patches/%s/%s.log' % (to_name, patch))
639         if os.path.isdir(self._dir()):
640             rename(os.path.join(self._basedir(), 'patches'),
641                    self.get_name(), to_stack.get_name())
642
643         # Rename the config section
644         for k in ['branch.%s', 'branch.%s.stgit']:
645             config.rename_section(k % self.get_name(), k % to_name)
646
647         self.__init__(to_name)
648
649     def clone(self, target_series):
650         """Clones a series
651         """
652         try:
653             # allow cloning of branches not under StGIT control
654             base = self.get_base()
655         except:
656             base = git.get_head()
657         Series(target_series).init(create_at = base)
658         new_series = Series(target_series)
659
660         # generate an artificial description file
661         new_series.set_description('clone of "%s"' % self.get_name())
662
663         # clone self's entire series as unapplied patches
664         try:
665             # allow cloning of branches not under StGIT control
666             applied = self.get_applied()
667             unapplied = self.get_unapplied()
668             patches = applied + unapplied
669             patches.reverse()
670         except:
671             patches = applied = unapplied = []
672         for p in patches:
673             patch = self.get_patch(p)
674             newpatch = new_series.new_patch(p, message = patch.get_description(),
675                                             can_edit = False, unapplied = True,
676                                             bottom = patch.get_bottom(),
677                                             top = patch.get_top(),
678                                             author_name = patch.get_authname(),
679                                             author_email = patch.get_authemail(),
680                                             author_date = patch.get_authdate())
681             if patch.get_log():
682                 out.info('Setting log to %s' %  patch.get_log())
683                 newpatch.set_log(patch.get_log())
684             else:
685                 out.info('No log for %s' % p)
686
687         # fast forward the cloned series to self's top
688         new_series.forward_patches(applied)
689
690         # Clone parent informations
691         value = config.get('branch.%s.remote' % self.get_name())
692         if value:
693             config.set('branch.%s.remote' % target_series, value)
694
695         value = config.get('branch.%s.merge' % self.get_name())
696         if value:
697             config.set('branch.%s.merge' % target_series, value)
698
699         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
700         if value:
701             config.set('branch.%s.stgit.parentbranch' % target_series, value)
702
703     def delete(self, force = False):
704         """Deletes an stgit series
705         """
706         if self.is_initialised():
707             patches = self.get_unapplied() + self.get_applied()
708             if not force and patches:
709                 raise StackException, \
710                       'Cannot delete: the series still contains patches'
711             for p in patches:
712                 self.get_patch(p).delete()
713
714             # remove the trash directory if any
715             if os.path.exists(self.__trash_dir):
716                 for fname in os.listdir(self.__trash_dir):
717                     os.remove(os.path.join(self.__trash_dir, fname))
718                 os.rmdir(self.__trash_dir)
719
720             # FIXME: find a way to get rid of those manual removals
721             # (move functionality to StgitObject ?)
722             if os.path.exists(self.__applied_file):
723                 os.remove(self.__applied_file)
724             if os.path.exists(self.__unapplied_file):
725                 os.remove(self.__unapplied_file)
726             if os.path.exists(self.__hidden_file):
727                 os.remove(self.__hidden_file)
728             if os.path.exists(self._dir()+'/orig-base'):
729                 os.remove(self._dir()+'/orig-base')
730
731             if not os.listdir(self.__patch_dir):
732                 os.rmdir(self.__patch_dir)
733             else:
734                 out.warn('Patch directory %s is not empty' % self.__patch_dir)
735
736             try:
737                 os.removedirs(self._dir())
738             except OSError:
739                 raise StackException('Series directory %s is not empty'
740                                      % self._dir())
741
742             try:
743                 git.delete_branch(self.get_name())
744             except GitException:
745                 out.warn('Could not delete branch "%s"' % self.get_name())
746
747         config.remove_section('branch.%s' % self.get_name())
748         config.remove_section('branch.%s.stgit' % self.get_name())
749
750     def refresh_patch(self, files = None, message = None, edit = False,
751                       show_patch = False,
752                       cache_update = True,
753                       author_name = None, author_email = None,
754                       author_date = None,
755                       committer_name = None, committer_email = None,
756                       backup = False, sign_str = None, log = 'refresh',
757                       notes = None, bottom = None):
758         """Generates a new commit for the topmost patch
759         """
760         patch = self.get_current_patch()
761         if not patch:
762             raise StackException, 'No patches applied'
763
764         descr = patch.get_description()
765         if not (message or descr):
766             edit = True
767             descr = ''
768         elif message:
769             descr = message
770
771         # TODO: move this out of the stgit.stack module, it is really
772         # for higher level commands to handle the user interaction
773         if not message and edit:
774             descr = edit_file(self, descr.rstrip(), \
775                               'Please edit the description for patch "%s" ' \
776                               'above.' % patch.get_name(), show_patch)
777
778         if not author_name:
779             author_name = patch.get_authname()
780         if not author_email:
781             author_email = patch.get_authemail()
782         if not author_date:
783             author_date = patch.get_authdate()
784         if not committer_name:
785             committer_name = patch.get_commname()
786         if not committer_email:
787             committer_email = patch.get_commemail()
788
789         descr = add_sign_line(descr, sign_str, committer_name, committer_email)
790
791         if not bottom:
792             bottom = patch.get_bottom()
793
794         commit_id = git.commit(files = files,
795                                message = descr, parents = [bottom],
796                                cache_update = cache_update,
797                                allowempty = True,
798                                author_name = author_name,
799                                author_email = author_email,
800                                author_date = author_date,
801                                committer_name = committer_name,
802                                committer_email = committer_email)
803
804         patch.set_bottom(bottom, backup = backup)
805         patch.set_top(commit_id, backup = backup)
806         patch.set_description(descr)
807         patch.set_authname(author_name)
808         patch.set_authemail(author_email)
809         patch.set_authdate(author_date)
810         patch.set_commname(committer_name)
811         patch.set_commemail(committer_email)
812
813         if log:
814             self.log_patch(patch, log, notes)
815
816         return commit_id
817
818     def undo_refresh(self):
819         """Undo the patch boundaries changes caused by 'refresh'
820         """
821         name = self.get_current()
822         assert(name)
823
824         patch = self.get_patch(name)
825         old_bottom = patch.get_old_bottom()
826         old_top = patch.get_old_top()
827
828         # the bottom of the patch is not changed by refresh. If the
829         # old_bottom is different, there wasn't any previous 'refresh'
830         # command (probably only a 'push')
831         if old_bottom != patch.get_bottom() or old_top == patch.get_top():
832             raise StackException, 'No undo information available'
833
834         git.reset(tree_id = old_top, check_out = False)
835         if patch.restore_old_boundaries():
836             self.log_patch(patch, 'undo')
837
838     def new_patch(self, name, message = None, can_edit = True,
839                   unapplied = False, show_patch = False,
840                   top = None, bottom = None, commit = True,
841                   author_name = None, author_email = None, author_date = None,
842                   committer_name = None, committer_email = None,
843                   before_existing = False):
844         """Creates a new patch, either pointing to an existing commit object,
845         or by creating a new commit object.
846         """
847
848         assert commit or (top and bottom)
849         assert not before_existing or (top and bottom)
850         assert not (commit and before_existing)
851         assert (top and bottom) or (not top and not bottom)
852         assert not top or (bottom == git.get_commit(top).get_parent())
853
854         if name != None:
855             self.__patch_name_valid(name)
856             if self.patch_exists(name):
857                 raise StackException, 'Patch "%s" already exists' % name
858
859         # TODO: move this out of the stgit.stack module, it is really
860         # for higher level commands to handle the user interaction
861         if not message and can_edit:
862             descr = edit_file(
863                 self, None,
864                 'Please enter the description for the patch above.',
865                 show_patch)
866         else:
867             descr = message
868
869         head = git.get_head()
870
871         if name == None:
872             name = make_patch_name(descr, self.patch_exists)
873
874         patch = self.get_patch(name)
875         patch.create()
876
877         if not bottom:
878             bottom = head
879         if not top:
880             top = head
881
882         patch.set_bottom(bottom)
883         patch.set_top(top)
884         patch.set_description(descr)
885         patch.set_authname(author_name)
886         patch.set_authemail(author_email)
887         patch.set_authdate(author_date)
888         patch.set_commname(committer_name)
889         patch.set_commemail(committer_email)
890
891         if before_existing:
892             insert_string(self.__applied_file, patch.get_name())
893         elif unapplied:
894             patches = [patch.get_name()] + self.get_unapplied()
895             write_strings(self.__unapplied_file, patches)
896             set_head = False
897         else:
898             append_string(self.__applied_file, patch.get_name())
899             set_head = True
900
901         if commit:
902             # create a commit for the patch (may be empty if top == bottom);
903             # only commit on top of the current branch
904             assert(unapplied or bottom == head)
905             top_commit = git.get_commit(top)
906             commit_id = git.commit(message = descr, parents = [bottom],
907                                    cache_update = False,
908                                    tree_id = top_commit.get_tree(),
909                                    allowempty = True, set_head = set_head,
910                                    author_name = author_name,
911                                    author_email = author_email,
912                                    author_date = author_date,
913                                    committer_name = committer_name,
914                                    committer_email = committer_email)
915             # set the patch top to the new commit
916             patch.set_top(commit_id)
917         else:
918             assert top != bottom
919
920         self.log_patch(patch, 'new')
921
922         return patch
923
924     def delete_patch(self, name):
925         """Deletes a patch
926         """
927         self.__patch_name_valid(name)
928         patch = self.get_patch(name)
929
930         if self.__patch_is_current(patch):
931             self.pop_patch(name)
932         elif self.patch_applied(name):
933             raise StackException, 'Cannot remove an applied patch, "%s", ' \
934                   'which is not current' % name
935         elif not name in self.get_unapplied():
936             raise StackException, 'Unknown patch "%s"' % name
937
938         # save the commit id to a trash file
939         write_string(os.path.join(self.__trash_dir, name), patch.get_top())
940
941         patch.delete()
942
943         unapplied = self.get_unapplied()
944         unapplied.remove(name)
945         write_strings(self.__unapplied_file, unapplied)
946
947     def forward_patches(self, names):
948         """Try to fast-forward an array of patches.
949
950         On return, patches in names[0:returned_value] have been pushed on the
951         stack. Apply the rest with push_patch
952         """
953         unapplied = self.get_unapplied()
954
955         forwarded = 0
956         top = git.get_head()
957
958         for name in names:
959             assert(name in unapplied)
960
961             patch = self.get_patch(name)
962
963             head = top
964             bottom = patch.get_bottom()
965             top = patch.get_top()
966
967             # top != bottom always since we have a commit for each patch
968             if head == bottom:
969                 # reset the backup information. No logging since the
970                 # patch hasn't changed
971                 patch.set_bottom(head, backup = True)
972                 patch.set_top(top, backup = True)
973
974             else:
975                 head_tree = git.get_commit(head).get_tree()
976                 bottom_tree = git.get_commit(bottom).get_tree()
977                 if head_tree == bottom_tree:
978                     # We must just reparent this patch and create a new commit
979                     # for it
980                     descr = patch.get_description()
981                     author_name = patch.get_authname()
982                     author_email = patch.get_authemail()
983                     author_date = patch.get_authdate()
984                     committer_name = patch.get_commname()
985                     committer_email = patch.get_commemail()
986
987                     top_tree = git.get_commit(top).get_tree()
988
989                     top = git.commit(message = descr, parents = [head],
990                                      cache_update = False,
991                                      tree_id = top_tree,
992                                      allowempty = True,
993                                      author_name = author_name,
994                                      author_email = author_email,
995                                      author_date = author_date,
996                                      committer_name = committer_name,
997                                      committer_email = committer_email)
998
999                     patch.set_bottom(head, backup = True)
1000                     patch.set_top(top, backup = True)
1001
1002                     self.log_patch(patch, 'push(f)')
1003                 else:
1004                     top = head
1005                     # stop the fast-forwarding, must do a real merge
1006                     break
1007
1008             forwarded+=1
1009             unapplied.remove(name)
1010
1011         if forwarded == 0:
1012             return 0
1013
1014         git.switch(top)
1015
1016         append_strings(self.__applied_file, names[0:forwarded])
1017         write_strings(self.__unapplied_file, unapplied)
1018
1019         return forwarded
1020
1021     def merged_patches(self, names):
1022         """Test which patches were merged upstream by reverse-applying
1023         them in reverse order. The function returns the list of
1024         patches detected to have been applied. The state of the tree
1025         is restored to the original one
1026         """
1027         patches = [self.get_patch(name) for name in names]
1028         patches.reverse()
1029
1030         merged = []
1031         for p in patches:
1032             if git.apply_diff(p.get_top(), p.get_bottom()):
1033                 merged.append(p.get_name())
1034         merged.reverse()
1035
1036         git.reset()
1037
1038         return merged
1039
1040     def push_empty_patch(self, name):
1041         """Pushes an empty patch on the stack
1042         """
1043         unapplied = self.get_unapplied()
1044         assert(name in unapplied)
1045
1046         # patch = self.get_patch(name)
1047         head = git.get_head()
1048
1049         append_string(self.__applied_file, name)
1050
1051         unapplied.remove(name)
1052         write_strings(self.__unapplied_file, unapplied)
1053
1054         self.refresh_patch(bottom = head, cache_update = False, log = 'push(m)')
1055
1056     def push_patch(self, name):
1057         """Pushes a patch on the stack
1058         """
1059         unapplied = self.get_unapplied()
1060         assert(name in unapplied)
1061
1062         patch = self.get_patch(name)
1063
1064         head = git.get_head()
1065         bottom = patch.get_bottom()
1066         top = patch.get_top()
1067         # top != bottom always since we have a commit for each patch
1068
1069         if head == bottom:
1070             # A fast-forward push. Just reset the backup
1071             # information. No need for logging
1072             patch.set_bottom(bottom, backup = True)
1073             patch.set_top(top, backup = True)
1074
1075             git.switch(top)
1076             append_string(self.__applied_file, name)
1077
1078             unapplied.remove(name)
1079             write_strings(self.__unapplied_file, unapplied)
1080             return False
1081
1082         # Need to create a new commit an merge in the old patch
1083         ex = None
1084         modified = False
1085
1086         # Try the fast applying first. If this fails, fall back to the
1087         # three-way merge
1088         if not git.apply_diff(bottom, top):
1089             # if git.apply_diff() fails, the patch requires a diff3
1090             # merge and can be reported as modified
1091             modified = True
1092
1093             # merge can fail but the patch needs to be pushed
1094             try:
1095                 git.merge(bottom, head, top, recursive = True)
1096             except git.GitException, ex:
1097                 out.error('The merge failed during "push".',
1098                           'Use "refresh" after fixing the conflicts or'
1099                           ' revert the operation with "push --undo".')
1100
1101         append_string(self.__applied_file, name)
1102
1103         unapplied.remove(name)
1104         write_strings(self.__unapplied_file, unapplied)
1105
1106         if not ex:
1107             # if the merge was OK and no conflicts, just refresh the patch
1108             # The GIT cache was already updated by the merge operation
1109             if modified:
1110                 log = 'push(m)'
1111             else:
1112                 log = 'push'
1113             self.refresh_patch(bottom = head, cache_update = False, log = log)
1114         else:
1115             # we store the correctly merged files only for
1116             # tracking the conflict history. Note that the
1117             # git.merge() operations should always leave the index
1118             # in a valid state (i.e. only stage 0 files)
1119             self.refresh_patch(bottom = head, cache_update = False,
1120                                log = 'push(c)')
1121             raise StackException, str(ex)
1122
1123         return modified
1124
1125     def undo_push(self):
1126         name = self.get_current()
1127         assert(name)
1128
1129         patch = self.get_patch(name)
1130         old_bottom = patch.get_old_bottom()
1131         old_top = patch.get_old_top()
1132
1133         # the top of the patch is changed by a push operation only
1134         # together with the bottom (otherwise the top was probably
1135         # modified by 'refresh'). If they are both unchanged, there
1136         # was a fast forward
1137         if old_bottom == patch.get_bottom() and old_top != patch.get_top():
1138             raise StackException, 'No undo information available'
1139
1140         git.reset()
1141         self.pop_patch(name)
1142         ret = patch.restore_old_boundaries()
1143         if ret:
1144             self.log_patch(patch, 'undo')
1145
1146         return ret
1147
1148     def pop_patch(self, name, keep = False):
1149         """Pops the top patch from the stack
1150         """
1151         applied = self.get_applied()
1152         applied.reverse()
1153         assert(name in applied)
1154
1155         patch = self.get_patch(name)
1156
1157         if git.get_head_file() == self.get_name():
1158             if keep and not git.apply_diff(git.get_head(), patch.get_bottom()):
1159                 raise StackException(
1160                     'Failed to pop patches while preserving the local changes')
1161             git.switch(patch.get_bottom(), keep)
1162         else:
1163             git.set_branch(self.get_name(), patch.get_bottom())
1164
1165         # save the new applied list
1166         idx = applied.index(name) + 1
1167
1168         popped = applied[:idx]
1169         popped.reverse()
1170         unapplied = popped + self.get_unapplied()
1171         write_strings(self.__unapplied_file, unapplied)
1172
1173         del applied[:idx]
1174         applied.reverse()
1175         write_strings(self.__applied_file, applied)
1176
1177     def empty_patch(self, name):
1178         """Returns True if the patch is empty
1179         """
1180         self.__patch_name_valid(name)
1181         patch = self.get_patch(name)
1182         bottom = patch.get_bottom()
1183         top = patch.get_top()
1184
1185         if bottom == top:
1186             return True
1187         elif git.get_commit(top).get_tree() \
1188                  == git.get_commit(bottom).get_tree():
1189             return True
1190
1191         return False
1192
1193     def rename_patch(self, oldname, newname):
1194         self.__patch_name_valid(newname)
1195
1196         applied = self.get_applied()
1197         unapplied = self.get_unapplied()
1198
1199         if oldname == newname:
1200             raise StackException, '"To" name and "from" name are the same'
1201
1202         if newname in applied or newname in unapplied:
1203             raise StackException, 'Patch "%s" already exists' % newname
1204
1205         if oldname in unapplied:
1206             self.get_patch(oldname).rename(newname)
1207             unapplied[unapplied.index(oldname)] = newname
1208             write_strings(self.__unapplied_file, unapplied)
1209         elif oldname in applied:
1210             self.get_patch(oldname).rename(newname)
1211
1212             applied[applied.index(oldname)] = newname
1213             write_strings(self.__applied_file, applied)
1214         else:
1215             raise StackException, 'Unknown patch "%s"' % oldname
1216
1217     def log_patch(self, patch, message, notes = None):
1218         """Generate a log commit for a patch
1219         """
1220         top = git.get_commit(patch.get_top())
1221         old_log = patch.get_log()
1222
1223         if message is None:
1224             # replace the current log entry
1225             if not old_log:
1226                 raise StackException, \
1227                       'No log entry to annotate for patch "%s"' \
1228                       % patch.get_name()
1229             replace = True
1230             log_commit = git.get_commit(old_log)
1231             msg = log_commit.get_log().split('\n')[0]
1232             log_parent = log_commit.get_parent()
1233             if log_parent:
1234                 parents = [log_parent]
1235             else:
1236                 parents = []
1237         else:
1238             # generate a new log entry
1239             replace = False
1240             msg = '%s\t%s' % (message, top.get_id_hash())
1241             if old_log:
1242                 parents = [old_log]
1243             else:
1244                 parents = []
1245
1246         if notes:
1247             msg += '\n\n' + notes
1248
1249         log = git.commit(message = msg, parents = parents,
1250                          cache_update = False, tree_id = top.get_tree(),
1251                          allowempty = True)
1252         patch.set_log(log)
1253
1254     def hide_patch(self, name):
1255         """Add the patch to the hidden list.
1256         """
1257         unapplied = self.get_unapplied()
1258         if name not in unapplied:
1259             # keep the checking order for backward compatibility with
1260             # the old hidden patches functionality
1261             if self.patch_applied(name):
1262                 raise StackException, 'Cannot hide applied patch "%s"' % name
1263             elif self.patch_hidden(name):
1264                 raise StackException, 'Patch "%s" already hidden' % name
1265             else:
1266                 raise StackException, 'Unknown patch "%s"' % name
1267
1268         if not self.patch_hidden(name):
1269             # check needed for backward compatibility with the old
1270             # hidden patches functionality
1271             append_string(self.__hidden_file, name)
1272
1273         unapplied.remove(name)
1274         write_strings(self.__unapplied_file, unapplied)
1275
1276     def unhide_patch(self, name):
1277         """Remove the patch from the hidden list.
1278         """
1279         hidden = self.get_hidden()
1280         if not name in hidden:
1281             if self.patch_applied(name) or self.patch_unapplied(name):
1282                 raise StackException, 'Patch "%s" not hidden' % name
1283             else:
1284                 raise StackException, 'Unknown patch "%s"' % name
1285
1286         hidden.remove(name)
1287         write_strings(self.__hidden_file, hidden)
1288
1289         if not self.patch_applied(name) and not self.patch_unapplied(name):
1290             # check needed for backward compatibility with the old
1291             # hidden patches functionality
1292             append_string(self.__unapplied_file, name)