granicus.if.org Git - python/commitdiff
Generalize operator.indexOf (PySequence_Index) to work with any
author Tim Peters <tim.peters@gmail.com>
Sat, 8 Sep 2001 04:00:12 +0000 (04:00 +0000)
committer Tim Peters <tim.peters@gmail.com>
Sat, 8 Sep 2001 04:00:12 +0000 (04:00 +0000)
iterable object.  I'm not sure how that got overlooked before!

Got rid of the internal _PySequence_IterContains, introduced a new
internal _PySequence_IterSearch, and rewrote all the iteration-based
"count of", "index of", and "is the object in it or not?" routines to
just call the new function.  I suppose it's slower this way, but the
code duplication was getting depressing.

Include/abstract.h
Lib/test/test_iter.py
Misc/NEWS
Objects/abstract.c
Objects/classobject.c
Objects/typeobject.c

index f4c1b3ec65b21e28d2fb0f71123f20b50fb44e7e..d736efcb9094cf2762496bfd3a61b5f5d21ed155 100644 (file)
@@ -988,14 +988,24 @@ xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx*/
      DL_IMPORT(int) PySequence_Contains(PyObject *seq, PyObject *ob);
        /*
          Return -1 if error; 1 if ob in seq; 0 if ob not in seq.
-         Use __contains__ if possible, else _PySequence_IterContains().
-       */
-
-     DL_IMPORT(int) _PySequence_IterContains(PyObject *seq, PyObject *ob);
-       /*
-         Return -1 if error; 1 if ob in seq; 0 if ob not in seq.
-         Always uses the iteration protocol, and only Py_EQ comparisons.
-       */
+         Use __contains__ if possible, else _PySequence_IterSearch().
+       */
+
+#define PY_ITERSEARCH_COUNT    1
+#define PY_ITERSEARCH_INDEX    2
+#define PY_ITERSEARCH_CONTAINS 3
+     DL_IMPORT(int) _PySequence_IterSearch(PyObject *seq, PyObject *obj,
+                   int operation);
+       /*
+         Iterate over seq.  Result depends on the operation:
+         PY_ITERSEARCH_COUNT:  return # of times obj appears in seq; -1 if
+               error.
+         PY_ITERSEARCH_INDEX:  return 0-based index of first occurrence of
+               obj in seq; set ValueError and return -1 if none found;
+               also return -1 on error.
+         PY_ITERSEARCH_CONTAINS:  return 1 if obj in seq, else 0; -1 on
+               error.
+       */
 
 /* For DLL-level backwards compatibility */
 #undef PySequence_In
index 8b6891b9cedfdccfee2395d42c654cef6b754ac8..37fab7ced7680972e58f1626721a9e0f36281dad 100644 (file)
@@ -600,6 +600,47 @@ class TestCase(unittest.TestCase):
             except OSError:
                 pass
 
+    # Test iterators with operator.indexOf (PySequence_Index).
+    def test_indexOf(self):
+        from operator import indexOf
+        self.assertEqual(indexOf([1,2,2,3,2,5], 1), 0)
+        self.assertEqual(indexOf((1,2,2,3,2,5), 2), 1)
+        self.assertEqual(indexOf((1,2,2,3,2,5), 3), 3)
+        self.assertEqual(indexOf((1,2,2,3,2,5), 5), 5)
+        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 0)
+        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 6)
+
+        self.assertEqual(indexOf("122325", "2"), 1)
+        self.assertEqual(indexOf("122325", "5"), 5)
+        self.assertRaises(ValueError, indexOf, "122325", "6")
+
+        self.assertRaises(TypeError, indexOf, 42, 1)
+        self.assertRaises(TypeError, indexOf, indexOf, indexOf)
+
+        f = open(TESTFN, "w")
+        try:
+            f.write("a\n" "b\n" "c\n" "d\n" "e\n")
+        finally:
+            f.close()
+        f = open(TESTFN, "r")
+        try:
+            fiter = iter(f)
+            self.assertEqual(indexOf(fiter, "b\n"), 1)
+            self.assertEqual(indexOf(fiter, "d\n"), 1)
+            self.assertEqual(indexOf(fiter, "e\n"), 0)
+            self.assertRaises(ValueError, indexOf, fiter, "a\n")
+        finally:
+            f.close()
+            try:
+                unlink(TESTFN)
+            except OSError:
+                pass
+
+        iclass = IteratingSequenceClass(3)
+        for i in range(3):
+            self.assertEqual(indexOf(iclass, i), i)
+        self.assertRaises(ValueError, indexOf, iclass, -1)
+
     # Test iterators on RHS of unpacking assignments.
     def test_unpack_iter(self):
         a, b = 1, 2
index 87bf717954fef911a3bc219717933ea6d7b3d057..ecc4588762784d1980346299e5623ca07dce3c80 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -5,6 +5,9 @@ Core
 
 Library
 
+- operator.indexOf() (PySequence_Index() in the C API) now works with any
+  iterable object.
+
 Tools
 
 Build
index c3a397c7c0d47d428b985e74cf45f21b6892f317..5361b1d5c1b4551f8d1da204876db899bd94999b 100644 (file)
@@ -1372,25 +1372,31 @@ PySequence_Fast(PyObject *v, const char *m)
        return v;
 }
 
-/* Return # of times o appears in s. */
+/* Iterate over seq.  Result depends on the operation:
+   PY_ITERSEARCH_COUNT:  -1 if error, else # of times obj appears in seq.
+   PY_ITERSEARCH_INDEX:  0-based index of first occurrence of obj in seq;
+       set ValueError and return -1 if none found; also return -1 on error.
+   PY_ITERSEARCH_CONTAINS:  return 1 if obj in seq, else 0; -1 on error.
+*/
 int
-PySequence_Count(PyObject *s, PyObject *o)
+_PySequence_IterSearch(PyObject *seq, PyObject *obj, int operation)
 {
-       int n;  /* running count of o hits */
-       PyObject *it;  /* iter(s) */
+       int n;
+       int wrapped;  /* for PY_ITERSEARCH_INDEX, true iff n wrapped around */
+       PyObject *it;  /* iter(seq) */
 
-       if (s == NULL || o == NULL) {
+       if (seq == NULL || obj == NULL) {
                null_error();
                return -1;
        }
 
-       it = PyObject_GetIter(s);
+       it = PyObject_GetIter(seq);
        if (it == NULL) {
-               type_error(".count() requires iterable argument");
+               type_error("iterable argument required");
                return -1;
        }
 
-       n = 0;
+       n = wrapped = 0;
        for (;;) {
                int cmp;
                PyObject *item = PyIter_Next(it);
@@ -1399,61 +1405,70 @@ PySequence_Count(PyObject *s, PyObject *o)
                                goto Fail;
                        break;
                }
-               cmp = PyObject_RichCompareBool(o, item, Py_EQ);
+
+               cmp = PyObject_RichCompareBool(obj, item, Py_EQ);
                Py_DECREF(item);
                if (cmp < 0)
                        goto Fail;
                if (cmp > 0) {
-                       if (n == INT_MAX) {
-                               PyErr_SetString(PyExc_OverflowError,
+                       switch (operation) {
+                       case PY_ITERSEARCH_COUNT:
+                               ++n;
+                               if (n <= 0) {
+                                       PyErr_SetString(PyExc_OverflowError,
                                                "count exceeds C int size");
-                               goto Fail;
+                                       goto Fail;
+                               }
+                               break;
+
+                       case PY_ITERSEARCH_INDEX:
+                               if (wrapped) {
+                                       PyErr_SetString(PyExc_OverflowError,
+                                               "index exceeds C int size");
+                                       goto Fail;
+                               }
+                               goto Done;
+
+                       case PY_ITERSEARCH_CONTAINS:
+                               n = 1;
+                               goto Done;
+
+                       default:
+                               assert(!"unknown operation");
                        }
-                       n++;
+               }
+
+               if (operation == PY_ITERSEARCH_INDEX) {
+                       ++n;
+                       if (n <= 0)
+                               wrapped = 1;
                }
        }
-       Py_DECREF(it);
-       return n;
 
+       if (operation != PY_ITERSEARCH_INDEX)
+               goto Done;
+
+       PyErr_SetString(PyExc_ValueError,
+                       "sequence.index(x): x not in sequence");
+       /* fall into failure code */
 Fail:
+       n = -1;
+       /* fall through */
+Done:
        Py_DECREF(it);
-       return -1;
+       return n;
+
 }
 
-/* Return -1 if error; 1 if ob in seq; 0 if ob not in seq.
- * Always uses the iteration protocol, and only Py_EQ comparison.
- */
+/* Return # of times o appears in s. */
 int
-_PySequence_IterContains(PyObject *seq, PyObject *ob)
+PySequence_Count(PyObject *s, PyObject *o)
 {
-       int result;
-       PyObject *it = PyObject_GetIter(seq);
-       if (it == NULL) {
-               PyErr_SetString(PyExc_TypeError,
-                       "'in' or 'not in' needs iterable right argument");
-               return -1;
-       }
-
-       for (;;) {
-               int cmp;
-               PyObject *item = PyIter_Next(it);
-               if (item == NULL) {
-                       result = PyErr_Occurred() ? -1 : 0;
-                       break;
-               }
-               cmp = PyObject_RichCompareBool(ob, item, Py_EQ);
-               Py_DECREF(item);
-               if (cmp == 0)
-                       continue;
-               result = cmp > 0 ? 1 : -1;
-               break;
-       }
-       Py_DECREF(it);
-       return result;
+       return _PySequence_IterSearch(s, o, PY_ITERSEARCH_COUNT);
 }
 
 /* Return -1 if error; 1 if ob in seq; 0 if ob not in seq.
- * Use sq_contains if possible, else defer to _PySequence_IterContains().
+ * Use sq_contains if possible, else defer to _PySequence_IterSearch().
  */
 int
 PySequence_Contains(PyObject *seq, PyObject *ob)
@@ -1463,7 +1478,7 @@ PySequence_Contains(PyObject *seq, PyObject *ob)
                if (sqm != NULL && sqm->sq_contains != NULL)
                        return (*sqm->sq_contains)(seq, ob);
        }
-       return _PySequence_IterContains(seq, ob);
+       return _PySequence_IterSearch(seq, ob, PY_ITERSEARCH_CONTAINS);
 }
 
 /* Backwards compatibility */
@@ -1477,32 +1492,7 @@ PySequence_In(PyObject *w, PyObject *v)
 int
 PySequence_Index(PyObject *s, PyObject *o)
 {
-       int l, i, cmp, err;
-       PyObject *item;
-
-       if (s == NULL || o == NULL) {
-               null_error();
-               return -1;
-       }
-       
-       l = PySequence_Size(s);
-       if (l < 0)
-               return -1;
-
-       for (i = 0; i < l; i++) {
-               item = PySequence_GetItem(s, i);
-               if (item == NULL)
-                       return -1;
-               err = PyObject_Cmp(item, o, &cmp);
-               Py_DECREF(item);
-               if (err < 0)
-                       return err;
-               if (cmp == 0)
-                       return i;
-       }
-
-       PyErr_SetString(PyExc_ValueError, "sequence.index(x): x not in list");
-       return -1;
+       return _PySequence_IterSearch(s, o, PY_ITERSEARCH_INDEX);
 }
 
 /* Operations on mappings */
index 4b698425324663a179987c54d2154df3751d61dc..9d84173b0b38e3d13d156691a82be652f69dbb16 100644 (file)
@@ -1224,7 +1224,8 @@ instance_contains(PyInstanceObject *inst, PyObject *member)
                 * __contains__ attribute, and try iterating instead.
                 */
                PyErr_Clear();
-               return _PySequence_IterContains((PyObject *)inst, member);
+               return _PySequence_IterSearch((PyObject *)inst, member,
+                                             PY_ITERSEARCH_CONTAINS);
        }
        else
                return -1;
index f15b096580bf8faf66014530f0b48c08784a6c4a..430e68ca0a084d3caf49c8164aeb4a2efaf9263d 100644 (file)
@@ -2559,7 +2559,8 @@ slot_sq_contains(PyObject *self, PyObject *value)
        }
        else {
                PyErr_Clear();
-               return _PySequence_IterContains(self, value);
+               return _PySequence_IterSearch(self, value,
+                                             PY_ITERSEARCH_CONTAINS);
        }
 }