]> granicus.if.org Git - python/commitdiff
Merged revisions 65605 via svnmerge from
authorSkip Montanaro <skip@pobox.com>
Sat, 9 Aug 2008 19:44:22 +0000 (19:44 +0000)
committerSkip Montanaro <skip@pobox.com>
Sat, 9 Aug 2008 19:44:22 +0000 (19:44 +0000)
svn+ssh://pythondev@svn.python.org/python/trunk

........
  r65605 | skip.montanaro | 2008-08-08 17:52:51 -0500 (Fri, 08 Aug 2008) | 1 line

  accept issue 3436
........

Doc/library/csv.rst
Lib/csv.py
Lib/test/test_csv.py

index 9043dbc4538e0576231c7571acdafc92a2942779..588a45cedf3a98bc43934cbc4e7a113bab2af63a 100644 (file)
@@ -372,6 +372,18 @@ Reader objects have the following public attributes:
 
 
 
+DictReader objects have the following public attribute:
+
+
+.. attribute:: csvreader.fieldnames
+
+   If not passed as a parameter when creating the object, this attribute is
+   initialized upon first access or when the first record is read from the
+   file.
+
+   .. versionchanged:: 2.6
+
+
 Writer Objects
 --------------
 
index 09f4cf444a367e07af030c4b73e2c9cd36de2c80..e0558c78d464e4f46f18c19abb3ad2c63038f43a 100644 (file)
@@ -68,7 +68,7 @@ register_dialect("excel-tab", excel_tab)
 class DictReader:
     def __init__(self, f, fieldnames=None, restkey=None, restval=None,
                  dialect="excel", *args, **kwds):
-        self.fieldnames = fieldnames    # list of keys for the dict
+        self._fieldnames = fieldnames   # list of keys for the dict
         self.restkey = restkey          # key to catch long rows
         self.restval = restval          # default value for short rows
         self.reader = reader(f, dialect, *args, **kwds)
@@ -78,11 +78,25 @@ class DictReader:
     def __iter__(self):
         return self
 
+    @property
+    def fieldnames(self):
+        if self._fieldnames is None:
+            try:
+                self._fieldnames = next(self.reader)
+            except StopIteration:
+                pass
+        self.line_num = self.reader.line_num
+        return self._fieldnames
+
+    @fieldnames.setter
+    def fieldnames(self, value):
+        self._fieldnames = value
+
     def __next__(self):
+        if self.line_num == 0:
+            # Used only for its side effect.
+            self.fieldnames
         row = next(self.reader)
-        if self.fieldnames is None:
-            self.fieldnames = row
-            row = next(self.reader)
         self.line_num = self.reader.line_num
 
         # unlike the basic reader, we prefer not to return blanks,
index 1dbb71a8bbb1dd94fca647390e5bf435b22e18c6..9c9840b01873000b1bbff8aff91ea85b5942b5a8 100644 (file)
@@ -544,6 +544,29 @@ class TestDictFields(unittest.TestCase):
             fileobj.seek(0)
             reader = csv.DictReader(fileobj)
             self.assertEqual(next(reader), {"f1": '1', "f2": '2', "f3": 'abc'})
+            self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"])
+
+    # Two test cases to make sure existing ways of implicitly setting
+    # fieldnames continue to work.  Both arise from discussion in issue3436.
+    def test_read_dict_fieldnames_from_file(self):
+        with TemporaryFile("w+") as fileobj:
+            fileobj.write("f1,f2,f3\r\n1,2,abc\r\n")
+            fileobj.seek(0)
+            reader = csv.DictReader(fileobj,
+                                    fieldnames=next(csv.reader(fileobj)))
+            self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"])
+            self.assertEqual(next(reader), {"f1": '1', "f2": '2', "f3": 'abc'})
+
+    def test_read_dict_fieldnames_chain(self):
+        import itertools
+        with TemporaryFile("w+") as fileobj:
+            fileobj.write("f1,f2,f3\r\n1,2,abc\r\n")
+            fileobj.seek(0)
+            reader = csv.DictReader(fileobj)
+            first = next(reader)
+            for row in itertools.chain([first], reader):
+                self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"])
+                self.assertEqual(row, {"f1": '1', "f2": '2', "f3": 'abc'})
 
     def test_read_long(self):
         with TemporaryFile("w+") as fileobj:
@@ -568,6 +591,7 @@ class TestDictFields(unittest.TestCase):
             fileobj.write("f1,f2\r\n1,2,abc,4,5,6\r\n")
             fileobj.seek(0)
             reader = csv.DictReader(fileobj, restkey="_rest")
+            self.assertEqual(reader.fieldnames, ["f1", "f2"])
             self.assertEqual(next(reader), {"f1": '1', "f2": '2',
                                              "_rest": ["abc", "4", "5", "6"]})