]> granicus.if.org Git - python/commitdiff
bpo-31222: Make (datetime|date|time).replace return subclass type in Pure Python...
authorMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>
Fri, 10 Nov 2017 00:52:05 +0000 (16:52 -0800)
committerVictor Stinner <victor.stinner@gmail.com>
Fri, 10 Nov 2017 00:52:05 +0000 (16:52 -0800)
(cherry picked from commit 191e993365ac3206f46132dcf46236471ec54bfa)

Lib/datetime.py
Lib/test/datetimetester.py

index b95536fb7afc02b2b4f9691c2d2bec4363ced121..150664ea3b6570f04bc4df145ffca896f20373e1 100644 (file)
@@ -827,7 +827,7 @@ class date:
             month = self._month
         if day is None:
             day = self._day
-        return date(year, month, day)
+        return type(self)(year, month, day)
 
     # Comparisons of date objects with other.
 
@@ -1315,7 +1315,7 @@ class time:
             tzinfo = self.tzinfo
         if fold is None:
             fold = self._fold
-        return time(hour, minute, second, microsecond, tzinfo, fold=fold)
+        return type(self)(hour, minute, second, microsecond, tzinfo, fold=fold)
 
     # Pickle support.
 
@@ -1596,7 +1596,7 @@ class datetime(date):
             tzinfo = self.tzinfo
         if fold is None:
             fold = self.fold
-        return datetime(year, month, day, hour, minute, second,
+        return type(self)(year, month, day, hour, minute, second,
                           microsecond, tzinfo, fold=fold)
 
     def _local_timezone(self):
index c00e38cb0c05826220055623cea973d79c8d758a..f23a5305e45123291b708652a603eb9485aa246e 100644 (file)
@@ -1500,6 +1500,13 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase):
         base = cls(2000, 2, 29)
         self.assertRaises(ValueError, base.replace, year=2001)
 
+    def test_subclass_replace(self):
+        class DateSubclass(self.theclass):
+            pass
+
+        dt = DateSubclass(2012, 1, 1)
+        self.assertIs(type(dt.replace(year=2013)), DateSubclass)
+
     def test_subclass_date(self):
 
         class C(self.theclass):
@@ -2599,6 +2606,13 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase):
         self.assertRaises(ValueError, base.replace, second=100)
         self.assertRaises(ValueError, base.replace, microsecond=1000000)
 
+    def test_subclass_replace(self):
+        class TimeSubclass(self.theclass):
+            pass
+
+        ctime = TimeSubclass(12, 30)
+        self.assertIs(type(ctime.replace(hour=10)), TimeSubclass)
+
     def test_subclass_time(self):
 
         class C(self.theclass):