]> granicus.if.org Git - python/commitdiff
improve type-safe of and prevent double-frees in get_locale_info (#28119)
authorBenjamin Peterson <benjamin@python.org>
Wed, 14 Sep 2016 05:43:45 +0000 (22:43 -0700)
committerBenjamin Peterson <benjamin@python.org>
Wed, 14 Sep 2016 05:43:45 +0000 (22:43 -0700)
Python/formatter_unicode.c

index e7c6a4f1a5f8dfb764c84518ce0fae6befe69fb1..617d58b2072d1008ec82e7ebe690922cab378176 100644 (file)
@@ -347,9 +347,11 @@ fill_padding(_PyUnicodeWriter *writer,
 /************************************************************************/
 
 /* Locale type codes. */
-#define LT_CURRENT_LOCALE 0
-#define LT_DEFAULT_LOCALE 1
-#define LT_NO_LOCALE 2
+enum LocaleType {
+    LT_CURRENT_LOCALE,
+    LT_DEFAULT_LOCALE,
+    LT_NO_LOCALE
+};
 
 /* Locale info needed for formatting integers and the part of floats
    before and including the decimal. Note that locales only support
@@ -663,7 +665,7 @@ static char no_grouping[1] = {CHAR_MAX};
    LT_CURRENT_LOCALE, a hard-coded locale if LT_DEFAULT_LOCALE, or
    none if LT_NO_LOCALE. */
 static int
-get_locale_info(int type, LocaleInfo *locale_info)
+get_locale_info(enum LocaleType type, LocaleInfo *locale_info)
 {
     switch (type) {
     case LT_CURRENT_LOCALE: {
@@ -676,21 +678,16 @@ get_locale_info(int type, LocaleInfo *locale_info)
         locale_info->thousands_sep = PyUnicode_DecodeLocale(
                                          locale_data->thousands_sep,
                                          NULL);
-        if (locale_info->thousands_sep == NULL) {
-            Py_DECREF(locale_info->decimal_point);
+        if (locale_info->thousands_sep == NULL)
             return -1;
-        }
         locale_info->grouping = locale_data->grouping;
         break;
     }
     case LT_DEFAULT_LOCALE:
         locale_info->decimal_point = PyUnicode_FromOrdinal('.');
         locale_info->thousands_sep = PyUnicode_FromOrdinal(',');
-        if (!locale_info->decimal_point || !locale_info->thousands_sep) {
-            Py_XDECREF(locale_info->decimal_point);
-            Py_XDECREF(locale_info->thousands_sep);
+        if (!locale_info->decimal_point || !locale_info->thousands_sep)
             return -1;
-        }
         locale_info->grouping = "\3"; /* Group every 3 characters.  The
                                          (implicit) trailing 0 means repeat
                                          infinitely. */
@@ -698,15 +695,10 @@ get_locale_info(int type, LocaleInfo *locale_info)
     case LT_NO_LOCALE:
         locale_info->decimal_point = PyUnicode_FromOrdinal('.');
         locale_info->thousands_sep = PyUnicode_New(0, 0);
-        if (!locale_info->decimal_point || !locale_info->thousands_sep) {
-            Py_XDECREF(locale_info->decimal_point);
-            Py_XDECREF(locale_info->thousands_sep);
+        if (!locale_info->decimal_point || !locale_info->thousands_sep)
             return -1;
-        }
         locale_info->grouping = no_grouping;
         break;
-    default:
-        assert(0);
     }
     return 0;
 }