]> granicus.if.org Git - python/commitdiff
Structure fields of type c_char array or c_wchar array accept bytes or
authorThomas Heller <theller@ctypes.org>
Fri, 13 Jul 2007 11:19:35 +0000 (11:19 +0000)
committerThomas Heller <theller@ctypes.org>
Fri, 13 Jul 2007 11:19:35 +0000 (11:19 +0000)
(unicode) string.

Lib/ctypes/test/test_bytes.py
Modules/_ctypes/cfield.c

index 25e017b8bf17b442763045a18d2f2b45268f044b..e6e047a82e33ccdd0371a5a69a9c27db4b7c3d32 100644 (file)
@@ -1,3 +1,4 @@
+"""Test where byte objects are accepted"""
 import unittest
 from ctypes import *
 
@@ -22,5 +23,19 @@ class BytesTest(unittest.TestCase):
         c_wchar_p("foo bar")
         c_wchar_p(b"foo bar")
 
+    def test_struct(self):
+        class X(Structure):
+            _fields_ = [("a", c_char * 3)]
+
+        X("abc")
+        X(b"abc")
+
+    def test_struct_W(self):
+        class X(Structure):
+            _fields_ = [("a", c_wchar * 3)]
+
+        X("abc")
+        X(b"abc")
+
 if __name__ == '__main__':
     unittest.main()
index a8d0d4bb833d13fc94354a56bff2af936fe884fe..8a0dfe7ce2d0c23e0ddeac2cfad9fb43e42d95aa 100644 (file)
@@ -1260,7 +1260,7 @@ U_set(void *ptr, PyObject *value, Py_ssize_t length)
        /* It's easier to calculate in characters than in bytes */
        length /= sizeof(wchar_t);
 
-       if (PyString_Check(value)) {
+       if (PyBytes_Check(value)) {
                value = PyUnicode_FromEncodedObject(value,
                                                    conversion_mode_encoding,
                                                    conversion_mode_errors);
@@ -1322,7 +1322,23 @@ s_set(void *ptr, PyObject *value, Py_ssize_t length)
        char *data;
        Py_ssize_t size;
 
-       data = PyString_AsString(value);
+       if (PyUnicode_Check(value)) {
+               value = PyUnicode_AsEncodedString(value,
+                                                 conversion_mode_encoding,
+                                                 conversion_mode_errors);
+               if (value == NULL)
+                       return NULL;
+               assert(PyBytes_Check(value));
+       } else if(PyBytes_Check(value)) {
+               Py_INCREF(value);
+       } else {
+               PyErr_Format(PyExc_TypeError,
+                            "expected string, %s found",
+                            value->ob_type->tp_name);
+               return NULL;
+       }
+
+       data = PyBytes_AsString(value);
        if (!data)
                return NULL;
        size = strlen(data);
@@ -1339,10 +1355,13 @@ s_set(void *ptr, PyObject *value, Py_ssize_t length)
                             "string too long (%zd, maximum length %zd)",
 #endif
                             size, length);
+               Py_DECREF(value);
                return NULL;
        }
        /* Also copy the terminating NUL character if there is space */
        memcpy((char *)ptr, data, size);
+
+       Py_DECREF(value);
        _RET(value);
 }