]> granicus.if.org Git - python/commitdiff
datetime.datetime and datetime.time can now be subclassed in Python. Brr.
authorTim Peters <tim.peters@gmail.com>
Sat, 17 May 2003 05:55:19 +0000 (05:55 +0000)
committerTim Peters <tim.peters@gmail.com>
Sat, 17 May 2003 05:55:19 +0000 (05:55 +0000)
Lib/test/test_datetime.py
Misc/NEWS
Modules/datetimemodule.c

index 51b5f4fce3baa56f9c708878c6ab16d84a681a63..c4978f3777768d867ba7a0e46b2db5c4bdca071a 100644 (file)
@@ -479,37 +479,6 @@ class TestDateOnly(unittest.TestCase):
         dt2 = dt - delta
         self.assertEqual(dt2, dt - days)
 
-    def test_subclass_date(self):
-
-        # XXX When datetime becomes usably subclassable, uncomment the
-        # XXX "self.theclass" lines and move this into TestDate.
-        # class C(self.theclass):
-        class C(date):
-            theAnswer = 42
-
-            def __new__(cls, *args, **kws):
-                temp = kws.copy()
-                extra = temp.pop('extra')
-                # result = self.theclass.__new__(cls, *args, **temp)
-                result = date.__new__(cls, *args, **temp)
-                result.extra = extra
-                return result
-
-            def newmeth(self, start):
-                return start + self.year + self.month
-
-        args = 2003, 4, 14
-
-        # dt1 = self.theclass(*args)
-        dt1 = date(*args)
-        dt2 = C(*args, **{'extra': 7})
-
-        self.assertEqual(dt2.__class__, C)
-        self.assertEqual(dt2.theAnswer, 42)
-        self.assertEqual(dt2.extra, 7)
-        self.assertEqual(dt1.toordinal(), dt2.toordinal())
-        self.assertEqual(dt2.newmeth(-7), dt1.year + dt1.month - 7)
-
 class TestDate(HarmlessMixedComparison):
     # Tests here should pass for both dates and datetimes, except for a
     # few tests that TestDateTime overrides.
@@ -1002,6 +971,32 @@ class TestDate(HarmlessMixedComparison):
         base = cls(2000, 2, 29)
         self.assertRaises(ValueError, base.replace, year=2001)
 
+    def test_subclass_date(self):
+
+        class C(self.theclass):
+            theAnswer = 42
+
+            def __new__(cls, *args, **kws):
+                temp = kws.copy()
+                extra = temp.pop('extra')
+                result = self.theclass.__new__(cls, *args, **temp)
+                result.extra = extra
+                return result
+
+            def newmeth(self, start):
+                return start + self.year + self.month
+
+        args = 2003, 4, 14
+
+        dt1 = self.theclass(*args)
+        dt2 = C(*args, **{'extra': 7})
+
+        self.assertEqual(dt2.__class__, C)
+        self.assertEqual(dt2.theAnswer, 42)
+        self.assertEqual(dt2.extra, 7)
+        self.assertEqual(dt1.toordinal(), dt2.toordinal())
+        self.assertEqual(dt2.newmeth(-7), dt1.year + dt1.month - 7)
+
 
 #############################################################################
 # datetime tests
@@ -1426,6 +1421,33 @@ class TestDateTime(TestDate):
         alsobog = AlsoBogus()
         self.assertRaises(ValueError, dt.astimezone, alsobog) # also naive
 
+    def test_subclass_datetime(self):
+
+        class C(self.theclass):
+            theAnswer = 42
+
+            def __new__(cls, *args, **kws):
+                temp = kws.copy()
+                extra = temp.pop('extra')
+                result = self.theclass.__new__(cls, *args, **temp)
+                result.extra = extra
+                return result
+
+            def newmeth(self, start):
+                return start + self.year + self.month + self.second
+
+        args = 2003, 4, 14, 12, 13, 41
+
+        dt1 = self.theclass(*args)
+        dt2 = C(*args, **{'extra': 7})
+
+        self.assertEqual(dt2.__class__, C)
+        self.assertEqual(dt2.theAnswer, 42)
+        self.assertEqual(dt2.extra, 7)
+        self.assertEqual(dt1.toordinal(), dt2.toordinal())
+        self.assertEqual(dt2.newmeth(-7), dt1.year + dt1.month +
+                                          dt1.second - 7)
+
 class TestTime(HarmlessMixedComparison):
 
     theclass = time
@@ -1660,6 +1682,32 @@ class TestTime(HarmlessMixedComparison):
         self.assertRaises(ValueError, base.replace, second=100)
         self.assertRaises(ValueError, base.replace, microsecond=1000000)
 
+    def test_subclass_time(self):
+
+        class C(self.theclass):
+            theAnswer = 42
+
+            def __new__(cls, *args, **kws):
+                temp = kws.copy()
+                extra = temp.pop('extra')
+                result = self.theclass.__new__(cls, *args, **temp)
+                result.extra = extra
+                return result
+
+            def newmeth(self, start):
+                return start + self.hour + self.second
+
+        args = 4, 5, 6
+
+        dt1 = self.theclass(*args)
+        dt2 = C(*args, **{'extra': 7})
+
+        self.assertEqual(dt2.__class__, C)
+        self.assertEqual(dt2.theAnswer, 42)
+        self.assertEqual(dt2.extra, 7)
+        self.assertEqual(dt1.isoformat(), dt2.isoformat())
+        self.assertEqual(dt2.newmeth(-7), dt1.hour + dt1.second - 7)
+
 # A mixin for classes with a tzinfo= argument.  Subclasses must define
 # theclass as a class atribute, and theclass(1, 1, 1, tzinfo=whatever)
 # must be legit (which is true for time and datetime).
@@ -2042,6 +2090,32 @@ class TestTimeTZ(TestTime, TZInfoBase):
         t2 = t2.replace(tzinfo=Varies())
         self.failUnless(t1 < t2)  # t1's offset counter still going up
 
+    def test_subclass_timetz(self):
+
+        class C(self.theclass):
+            theAnswer = 42
+
+            def __new__(cls, *args, **kws):
+                temp = kws.copy()
+                extra = temp.pop('extra')
+                result = self.theclass.__new__(cls, *args, **temp)
+                result.extra = extra
+                return result
+
+            def newmeth(self, start):
+                return start + self.hour + self.second
+
+        args = 4, 5, 6, 500, FixedOffset(-300, "EST", 1)
+
+        dt1 = self.theclass(*args)
+        dt2 = C(*args, **{'extra': 7})
+
+        self.assertEqual(dt2.__class__, C)
+        self.assertEqual(dt2.theAnswer, 42)
+        self.assertEqual(dt2.extra, 7)
+        self.assertEqual(dt1.utcoffset(), dt2.utcoffset())
+        self.assertEqual(dt2.newmeth(-7), dt1.hour + dt1.second - 7)
+
 
 # Testing datetime objects with a non-None tzinfo.
 
@@ -2625,6 +2699,32 @@ class TestDateTimeTZ(TestDateTime, TZInfoBase):
         t2 = t2.replace(tzinfo=Varies())
         self.failUnless(t1 < t2)  # t1's offset counter still going up
 
+    def test_subclass_datetimetz(self):
+
+        class C(self.theclass):
+            theAnswer = 42
+
+            def __new__(cls, *args, **kws):
+                temp = kws.copy()
+                extra = temp.pop('extra')
+                result = self.theclass.__new__(cls, *args, **temp)
+                result.extra = extra
+                return result
+
+            def newmeth(self, start):
+                return start + self.hour + self.year
+
+        args = 2002, 12, 31, 4, 5, 6, 500, FixedOffset(-300, "EST", 1)
+
+        dt1 = self.theclass(*args)
+        dt2 = C(*args, **{'extra': 7})
+
+        self.assertEqual(dt2.__class__, C)
+        self.assertEqual(dt2.theAnswer, 42)
+        self.assertEqual(dt2.extra, 7)
+        self.assertEqual(dt1.utcoffset(), dt2.utcoffset())
+        self.assertEqual(dt2.newmeth(-7), dt1.hour + dt1.year - 7)
+
 # Pain to set up DST-aware tzinfo classes.
 
 def first_sunday_on_or_after(dt):
index 6e8cae69bdb426dd106ae8d8ca11c0a141b51d7e..e182845729f1399e5399769750e1ffba08bbf20b 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -26,6 +26,9 @@ Core and builtins
 Extension modules
 -----------------
 
+- The datetime.datetime and datetime.time classes are now properly
+  subclassable.
+
 - _tkinter.{get|set}busywaitinterval was added.
 
 - itertools.islice() now accepts stop=None as documented.
index d25eb7048ba8a545795385039bbb14e848c7609c..36793382fe5c82b8145e5647d3f345c0bfd42820 100644 (file)
@@ -1235,38 +1235,42 @@ cmperror(PyObject *a, PyObject *b)
 }
 
 /* ---------------------------------------------------------------------------
- * Basic object allocation.  These allocate Python objects of the right
- * size and type, and do the Python object-initialization bit.  If there's
- * not enough memory, they return NULL after setting MemoryError.  All
- * data members remain uninitialized trash.
+ * Basic object allocation:  tp_alloc implementatiosn.  These allocate
+ * Python objects of the right size and type, and do the Python object-
+ * initialization bit.  If there's not enough memory, they return NULL after
+ * setting MemoryError.  All data members remain uninitialized trash.
+ *
+ * We abuse the tp_alloc "nitems" argument to communicate whether a tzinfo
+ * member is needed.  This is ugly.
  */
-static PyDateTime_Time *
-alloc_time(int aware)
+
+static PyObject *
+time_alloc(PyTypeObject *type, int aware)
 {
-       PyDateTime_Time *self;
+       PyObject *self;
 
-       self = (PyDateTime_Time *)
+       self = (PyObject *)
                PyObject_MALLOC(aware ?
                                sizeof(PyDateTime_Time) :
                                sizeof(_PyDateTime_BaseTime));
        if (self == NULL)
-               return (PyDateTime_Time *)PyErr_NoMemory();
-       PyObject_INIT(self, &PyDateTime_TimeType);
+               return (PyObject *)PyErr_NoMemory();
+       PyObject_INIT(self, type);
        return self;
 }
 
-static PyDateTime_DateTime *
-alloc_datetime(int aware)
+static PyObject *
+datetime_alloc(PyTypeObject *type, int aware)
 {
-       PyDateTime_DateTime *self;
+       PyObject *self;
 
-       self = (PyDateTime_DateTime *)
+       self = (PyObject *)
                PyObject_MALLOC(aware ?
                                sizeof(PyDateTime_DateTime) :
                                sizeof(_PyDateTime_BaseDateTime));
        if (self == NULL)
-               return (PyDateTime_DateTime *)PyErr_NoMemory();
-       PyObject_INIT(self, &PyDateTime_DateTimeType);
+               return (PyObject *)PyErr_NoMemory();
+       PyObject_INIT(self, type);
        return self;
 }
 
@@ -1302,17 +1306,17 @@ new_date_ex(int year, int month, int day, PyTypeObject *type)
 }
 
 #define new_date(year, month, day) \
-       (new_date_ex(year, month, day, &PyDateTime_DateType))
+       new_date_ex(year, month, day, &PyDateTime_DateType)
 
 /* Create a datetime instance with no range checking. */
 static PyObject *
-new_datetime(int year, int month, int day, int hour, int minute,
-            int second, int usecond, PyObject *tzinfo)
+new_datetime_ex(int year, int month, int day, int hour, int minute,
+            int second, int usecond, PyObject *tzinfo, PyTypeObject *type)
 {
        PyDateTime_DateTime *self;
        char aware = tzinfo != Py_None;
 
-       self = alloc_datetime(aware);
+       self = (PyDateTime_DateTime *) (type->tp_alloc(type, aware));
        if (self != NULL) {
                self->hastzinfo = aware;
                set_date_fields((PyDateTime_Date *)self, year, month, day);
@@ -1328,14 +1332,19 @@ new_datetime(int year, int month, int day, int hour, int minute,
        return (PyObject *)self;
 }
 
+#define new_datetime(y, m, d, hh, mm, ss, us, tzinfo)          \
+       new_datetime_ex(y, m, d, hh, mm, ss, us, tzinfo,        \
+                       &PyDateTime_DateTimeType)
+
 /* Create a time instance with no range checking. */
 static PyObject *
-new_time(int hour, int minute, int second, int usecond, PyObject *tzinfo)
+new_time_ex(int hour, int minute, int second, int usecond,
+           PyObject *tzinfo, PyTypeObject *type)
 {
        PyDateTime_Time *self;
        char aware = tzinfo != Py_None;
 
-       self = alloc_time(aware);
+       self = (PyDateTime_Time *) (type->tp_alloc(type, aware));
        if (self != NULL) {
                self->hastzinfo = aware;
                self->hashcode = -1;
@@ -1351,6 +1360,9 @@ new_time(int hour, int minute, int second, int usecond, PyObject *tzinfo)
        return (PyObject *)self;
 }
 
+#define new_time(hh, mm, ss, us, tzinfo)               \
+       new_time_ex(hh, mm, ss, us, tzinfo, &PyDateTime_TimeType)
+
 /* Create a timedelta instance.  Normalize the members iff normalize is
  * true.  Passing false is a speed optimization, if you know for sure
  * that seconds and microseconds are already in their proper ranges.  In any
@@ -3014,7 +3026,8 @@ time_new(PyTypeObject *type, PyObject *args, PyObject *kw)
                        }
                }
                aware = (char)(tzinfo != Py_None);
-               me = alloc_time(aware);
+               me = (PyDateTime_Time *) time_alloc(&PyDateTime_TimeType,
+                                                   aware);
                if (me != NULL) {
                        char *pdata = PyString_AS_STRING(state);
 
@@ -3036,7 +3049,8 @@ time_new(PyTypeObject *type, PyObject *args, PyObject *kw)
                        return NULL;
                if (check_tzinfo_subclass(tzinfo) < 0)
                        return NULL;
-               self = new_time(hour, minute, second, usecond, tzinfo);
+               self = new_time_ex(hour, minute, second, usecond, tzinfo,
+                                  type);
        }
        return self;
 }
@@ -3439,7 +3453,7 @@ statichere PyTypeObject PyDateTime_TimeType = {
        0,                                      /* tp_descr_set */
        0,                                      /* tp_dictoffset */
        0,                                      /* tp_init */
-       0,                                      /* tp_alloc */
+       time_alloc,                             /* tp_alloc */
        time_new,                               /* tp_new */
        0,                                      /* tp_free */
 };
@@ -3534,7 +3548,9 @@ datetime_new(PyTypeObject *type, PyObject *args, PyObject *kw)
                        }
                }
                aware = (char)(tzinfo != Py_None);
-               me = alloc_datetime(aware);
+               me = (PyDateTime_DateTime *) datetime_alloc(
+                                               &PyDateTime_DateTimeType,
+                                               aware);
                if (me != NULL) {
                        char *pdata = PyString_AS_STRING(state);
 
@@ -3558,9 +3574,9 @@ datetime_new(PyTypeObject *type, PyObject *args, PyObject *kw)
                        return NULL;
                if (check_tzinfo_subclass(tzinfo) < 0)
                        return NULL;
-               self = new_datetime(year, month, day,
-                                   hour, minute, second, usecond,
-                                   tzinfo);
+               self = new_datetime_ex(year, month, day,
+                                       hour, minute, second, usecond,
+                                       tzinfo, type);
        }
        return self;
 }
@@ -4460,7 +4476,7 @@ statichere PyTypeObject PyDateTime_DateTimeType = {
        0,                                      /* tp_descr_set */
        0,                                      /* tp_dictoffset */
        0,                                      /* tp_init */
-       0,                                      /* tp_alloc */
+       datetime_alloc,                         /* tp_alloc */
        datetime_new,                           /* tp_new */
        0,                                      /* tp_free */
 };