]> granicus.if.org Git - python/commitdiff
SF patch #662433: Fill arraymodule's tp_iter and sq_contains slots
authorRaymond Hettinger <python@rcn.com>
Tue, 7 Jan 2003 01:58:52 +0000 (01:58 +0000)
committerRaymond Hettinger <python@rcn.com>
Tue, 7 Jan 2003 01:58:52 +0000 (01:58 +0000)
Lib/test/test_array.py
Modules/arraymodule.c

index b650033292fc64574e1608db84a5725aaf1d582f..6aa3cf09312440822e097cab358f0f44d0161517 100755 (executable)
@@ -356,6 +356,12 @@ def testtype(type, example):
         a[2:3] = ins
         b[slice(2,3)] = ins
         c[2:3:] = ins
+        # iteration and contains
+        a = array.array(type, range(10))
+        vereq(list(a), range(10))
+        b = array.array(type, [20])
+        vereq(a[-1] in a, True)
+        vereq(b[0] not in a, True)
 
     # test that overflow exceptions are raised as expected for assignment
     # to array of specific integral types
index 03447cb9e382b4d4914e5f761baf7d90fae49518..e048d995a394b7331b3415330b281d2e61181703 100644 (file)
@@ -844,6 +844,19 @@ PyDoc_STRVAR(index_doc,
 \n\
 Return index of first occurence of x in the array.");
 
+static int
+array_contains(arrayobject *self, PyObject *v)
+{
+       int i, cmp;
+
+       for (i = 0, cmp = 0 ; cmp == 0 && i < self->ob_size; i++) {
+               PyObject *selfi = getarrayitem((PyObject *)self, i);
+               cmp = PyObject_RichCompareBool(selfi, v, Py_EQ);
+               Py_DECREF(selfi);
+       }
+       return cmp;
+}
+
 static PyObject *
 array_remove(arrayobject *self, PyObject *v)
 {
@@ -1655,7 +1668,7 @@ static PySequenceMethods array_as_sequence = {
        (intintargfunc)array_slice,             /*sq_slice*/
        (intobjargproc)array_ass_item,          /*sq_ass_item*/
        (intintobjargproc)array_ass_slice,      /*sq_ass_slice*/
-       NULL,                                   /*sq_contains*/
+       (objobjproc)array_contains,             /*sq_contains*/
        (binaryfunc)array_inplace_concat,       /*sq_inplace_concat*/
        (intargfunc)array_inplace_repeat        /*sq_inplace_repeat*/
 };
@@ -1822,6 +1835,8 @@ typecode -- the typecode character used to create the array\n\
 itemsize -- the length in bytes of one array item\n\
 ");
 
+static PyObject *array_iter(arrayobject *ao);
+
 static PyTypeObject Arraytype = {
        PyObject_HEAD_INIT(NULL)
        0,
@@ -1849,7 +1864,7 @@ static PyTypeObject Arraytype = {
        0,                                      /* tp_clear */
        array_richcompare,                      /* tp_richcompare */
        0,                                      /* tp_weaklistoffset */
-       0,                                      /* tp_iter */
+       (getiterfunc)array_iter,                /* tp_iter */
        0,                                      /* tp_iternext */
        array_methods,                          /* tp_methods */
        0,                                      /* tp_members */
@@ -1865,6 +1880,110 @@ static PyTypeObject Arraytype = {
        PyObject_Del,                           /* tp_free */
 };
 
+
+/*********************** Array Iterator **************************/
+
+typedef struct {
+       PyObject_HEAD
+       long                    index;
+       arrayobject             *ao;
+       PyObject                * (*getitem)(struct arrayobject *, int);
+} arrayiterobject;
+
+static PyTypeObject PyArrayIter_Type;
+
+#define PyArrayIter_Check(op) PyObject_TypeCheck(op, &PyArrayIter_Type)
+
+static PyObject *
+array_iter(arrayobject *ao)
+{
+       arrayiterobject *it;
+
+       if (!array_Check(ao)) {
+               PyErr_BadInternalCall();
+               return NULL;
+       }
+
+       it = PyObject_GC_New(arrayiterobject, &PyArrayIter_Type);
+       if (it == NULL)
+               return NULL;
+
+       Py_INCREF(ao);
+       it->ao = ao;
+       it->index = 0;
+       it->getitem = ao->ob_descr->getitem;
+       PyObject_GC_Track(it);
+       return (PyObject *)it;
+}
+
+static PyObject *
+arrayiter_getiter(PyObject *it)
+{
+       Py_INCREF(it);
+       return it;
+}
+
+static PyObject *
+arrayiter_next(arrayiterobject *it)
+{
+       assert(PyArrayIter_Check(it));
+       if (it->index < it->ao->ob_size)
+               return (*it->getitem)(it->ao, it->index++);
+       return NULL;
+}
+
+static void
+arrayiter_dealloc(arrayiterobject *it)
+{
+       PyObject_GC_UnTrack(it);
+       Py_XDECREF(it->ao);
+       PyObject_GC_Del(it);
+}
+
+static int
+arrayiter_traverse(arrayiterobject *it, visitproc visit, void *arg)
+{
+       if (it->ao != NULL)
+               return visit((PyObject *)(it->ao), arg);
+       return 0;
+}
+
+static PyTypeObject PyArrayIter_Type = {
+       PyObject_HEAD_INIT(&PyType_Type)
+       0,                                      /* ob_size */
+       "arrayiterator",                        /* tp_name */
+       sizeof(arrayiterobject),                /* tp_basicsize */
+       0,                                      /* tp_itemsize */
+       /* methods */
+       (destructor)arrayiter_dealloc,          /* tp_dealloc */
+       0,                                      /* tp_print */
+       0,                                      /* tp_getattr */
+       0,                                      /* tp_setattr */
+       0,                                      /* tp_compare */
+       0,                                      /* tp_repr */
+       0,                                      /* tp_as_number */
+       0,                                      /* tp_as_sequence */
+       0,                                      /* tp_as_mapping */
+       0,                                      /* tp_hash */
+       0,                                      /* tp_call */
+       0,                                      /* tp_str */
+       PyObject_GenericGetAttr,                /* tp_getattro */
+       0,                                      /* tp_setattro */
+       0,                                      /* tp_as_buffer */
+       Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,/* tp_flags */
+       0,                                      /* tp_doc */
+       (traverseproc)arrayiter_traverse,       /* tp_traverse */
+       0,                                      /* tp_clear */
+       0,                                      /* tp_richcompare */
+       0,                                      /* tp_weaklistoffset */
+       (getiterfunc)arrayiter_getiter,         /* tp_iter */
+       (iternextfunc)arrayiter_next,           /* tp_iternext */
+       0,                                      /* tp_methods */
+};
+
+
+/*********************** Install Module **************************/
+
 /* No functions in array module. */
 static PyMethodDef a_methods[] = {
     {NULL, NULL, 0, NULL}        /* Sentinel */