]> granicus.if.org Git - python/commitdiff
Raise an exception when src and dst refer to the same file via a hard link or a
authorJohannes Gijsbers <jlg@dds.nl>
Sat, 14 Aug 2004 13:30:02 +0000 (13:30 +0000)
committerJohannes Gijsbers <jlg@dds.nl>
Sat, 14 Aug 2004 13:30:02 +0000 (13:30 +0000)
symbolic link (bug #851123 / patch #854853, thanks Gregory Ball).

Lib/shutil.py
Lib/test/test_shutil.py

index fde8c90fe9fb596401ff0253135aa5280c4cba13..d361fa2b5d5caef66044b2653a4d4a8b80faab81 100644 (file)
@@ -24,16 +24,22 @@ def copyfileobj(fsrc, fdst, length=16*1024):
             break
         fdst.write(buf)
 
+def _samefile(src, dst):
+    # Macintosh, Unix.
+    if hasattr(os.path,'samefile'):
+        return os.path.samefile(src, dst)
+
+    # All other platforms: check for same pathname.
+    return (os.path.normcase(os.path.abspath(src)) ==
+            os.path.normcase(os.path.abspath(dst)))
 
 def copyfile(src, dst):
     """Copy data from src to dst"""
+    if _samefile(src, dst):
+        raise Error, "`%s` and `%s` are the same file" % (src, dst)
+
     fsrc = None
     fdst = None
-    # check for same pathname; all platforms
-    _src = os.path.normcase(os.path.abspath(src))
-    _dst = os.path.normcase(os.path.abspath(dst))
-    if _src == _dst:
-        return
     try:
         fsrc = open(src, 'rb')
         fdst = open(dst, 'wb')
index bcae72f1daca5a5b40561f61005a813bd9438c4f..083dbda706c9c8f63a8a708ee0f69f45a085e551 100644 (file)
@@ -6,6 +6,7 @@ import tempfile
 import os
 import os.path
 from test import test_support
+from test.test_support import TESTFN
 
 class TestShutil(unittest.TestCase):
     def test_rmtree_errors(self):
@@ -26,6 +27,26 @@ class TestShutil(unittest.TestCase):
             except:
                 pass
 
+    if hasattr(os, "symlink"):
+        def test_dont_copy_file_onto_link_to_itself(self):
+            # bug 851123.
+            os.mkdir(TESTFN)
+            src = os.path.join(TESTFN,'cheese')
+            dst = os.path.join(TESTFN,'shop')
+            try:
+                f = open(src,'w')
+                f.write('cheddar')
+                f.close()
+                for funcname in 'link','symlink':
+                    getattr(os, funcname)(src, dst)
+                    self.assertRaises(shutil.Error, shutil.copyfile, src, dst)
+                    self.assertEqual(open(src,'r').read(), 'cheddar')
+                    os.remove(dst)
+            finally:
+                try:
+                    shutil.rmtree(TESTFN)
+                except OSError:
+                    pass
 
 
 def test_main():