chiark / gitweb /
Merge branch 'stable'
[stgit] / stgit / commands / imprt.py
index 045f18513213b4ccc4a7c7e8605fe0b0e4b05470..227743f51554f027832f143ef6b4c30e76809f72 100644 (file)
@@ -129,7 +129,7 @@ def __create_patch(filename, message, author_name, author_email,
         out.info('Ignoring already applied patch "%s"' % patch)
         return
     if options.replace and patch in crt_series.get_unapplied():
         out.info('Ignoring already applied patch "%s"' % patch)
         return
     if options.replace and patch in crt_series.get_unapplied():
-        crt_series.delete_patch(patch)
+        crt_series.delete_patch(patch, keep_log = True)
 
     # refresh_patch() will invoke the editor in this case, with correct
     # patch content
 
     # refresh_patch() will invoke the editor in this case, with correct
     # patch content
@@ -165,22 +165,52 @@ def __create_patch(filename, message, author_name, author_email,
     else:
         out.start('Importing patch "%s"' % patch)
         if options.base:
     else:
         out.start('Importing patch "%s"' % patch)
         if options.base:
-            git.apply_patch(diff = diff, base = git_id(options.base))
+            git.apply_patch(diff = diff,
+                            base = git_id(crt_series, options.base))
         else:
             git.apply_patch(diff = diff)
         crt_series.refresh_patch(edit = options.edit,
                                  show_patch = options.showpatch,
         else:
             git.apply_patch(diff = diff)
         crt_series.refresh_patch(edit = options.edit,
                                  show_patch = options.showpatch,
-                                 sign_str = options.sign_str)
+                                 sign_str = options.sign_str,
+                                 backup = False)
         out.done()
 
         out.done()
 
+def __mkpatchname(name, suffix):
+    """Return name with a trailing suffix removed.
+
+    The suffix match is case-insensitive ('.GZ' matches '.gz').  If name
+    does not end with suffix, or suffix is empty, name is returned
+    unchanged.
+    """
+    # Guard against an empty suffix: ''.endswith('') is true and
+    # name[:-0] == name[:0] == '', which would wipe out the whole name.
+    if suffix and name.lower().endswith(suffix.lower()):
+        return name[:-len(suffix)]
+    return name
+
+def __get_handle_and_name(filename):
+    """Return a file object and a patch name derived from filename.
+
+    The file is probed first as gzip, then as bzip2: if a decoder can
+    read the first byte, its handle is rewound to the start and returned
+    together with filename minus the matching '.gz'/'.bz2' extension
+    (when present).  Otherwise the file is opened as a plain file and
+    filename is returned unchanged; an IOError from that final open()
+    (e.g. missing file) propagates to the caller.
+    """
+    # see if it's a gzip'ed or bzip2'ed patch
+    import bz2, gzip
+    for copen, ext in [(gzip.open, '.gz'), (bz2.BZ2File, '.bz2')]:
+        f = None
+        try:
+            f = copen(filename)
+            # the decoders only complain about bad data once reading starts
+            f.read(1)
+            f.seek(0)
+            return (f, __mkpatchname(filename, ext))
+        except IOError:
+            # not this compression format (or not readable): release the
+            # probe handle before trying the next decoder
+            if f is not None:
+                f.close()
+
+    # plain old file...
+    return (open(filename), filename)
+
 def __import_file(filename, options, patch = None):
     """Import a patch from a file or standard input
     """
 def __import_file(filename, options, patch = None):
     """Import a patch from a file or standard input
     """
+    pname = None
     if filename:
     if filename:
-        f = file(filename)
+        (f, pname) = __get_handle_and_name(filename)
     else:
         f = sys.stdin
 
     else:
         f = sys.stdin
 
+    if patch:
+        pname = patch
+    elif not pname:
+        pname = filename
+
     if options.mail:
         try:
             msg = email.message_from_file(f)
     if options.mail:
         try:
             msg = email.message_from_file(f)
@@ -190,16 +220,11 @@ def __import_file(filename, options, patch = None):
                  parse_mail(msg)
     else:
         message, author_name, author_email, author_date, diff = \
                  parse_mail(msg)
     else:
         message, author_name, author_email, author_date, diff = \
-                 parse_patch(f)
+                 parse_patch(f.read())
 
     if filename:
         f.close()
 
 
     if filename:
         f.close()
 
-    if patch:
-        pname = patch
-    else:
-        pname = filename
-
     __create_patch(pname, message, author_name, author_email,
                    author_date, diff, options)
 
     __create_patch(pname, message, author_name, author_email,
                    author_date, diff, options)
 
@@ -270,7 +295,7 @@ def func(parser, options, args):
 
     check_local_changes()
     check_conflicts()
 
     check_local_changes()
     check_conflicts()
-    check_head_top_equal()
+    check_head_top_equal(crt_series)
 
     if len(args) == 1:
         filename = args[0]
 
     if len(args) == 1:
         filename = args[0]
@@ -286,4 +311,4 @@ def func(parser, options, args):
     else:
         __import_file(filename, options)
 
     else:
         __import_file(filename, options)
 
-    print_crt_patch()
+    print_crt_patch(crt_series)