]> granicus.if.org Git - postgresql/commitdiff
Add defenses against integer overflow in dynahash numbuckets calculations.
authorTom Lane <tgl@sss.pgh.pa.us>
Wed, 12 Dec 2012 03:09:05 +0000 (22:09 -0500)
committerTom Lane <tgl@sss.pgh.pa.us>
Wed, 12 Dec 2012 03:09:05 +0000 (22:09 -0500)
The dynahash code requires the number of buckets in a hash table to fit
in an int; but since we calculate the desired hash table size dynamically,
there are various scenarios where we might calculate too large a value.
The resulting overflow can lead to infinite loops, division-by-zero
crashes, etc.  I (tgl) had previously installed some defenses against that
in commit 299d1716525c659f0e02840e31fbe4dea3, but that covered only one
call path.  Moreover it worked by limiting the request size to work_mem,
but in a 64-bit machine it's possible to set work_mem high enough that the
problem appears anyway.  So let's fix the problem at the root by installing
limits in the dynahash.c functions themselves.

Trouble report and patch by Jeff Davis.

src/backend/executor/nodeHash.c
src/backend/utils/hash/dynahash.c

index c90fe40b3c9b21073943b2b32117069052917e87..5d0fc77c3015a739e2270d16a704818aa3bbedc9 100644 (file)
@@ -500,7 +500,9 @@ ExecChooseHashTableSize(double ntuples, int tupwidth, bool useskew,
         * Both nbuckets and nbatch must be powers of 2 to make
         * ExecHashGetBucketAndBatch fast.      We already fixed nbatch; now inflate
         * nbuckets to the next larger power of 2.      We also force nbuckets to not
-        * be real small, by starting the search at 2^10.
+        * be real small, by starting the search at 2^10.  (Note: above we made
+        * sure that nbuckets is not more than INT_MAX / 2, so this loop cannot
+        * overflow, nor can the final shift to recalculate nbuckets.)
         */
        i = 10;
        while ((1 << i) < nbuckets)
index 31ac2b25e8ffa2f50dfa4c9bde0d248baf2c0979..07f7a84943029f876abcd853c7b2b85bdd0e048a 100644 (file)
@@ -68,6 +68,8 @@
 
 #include "postgres.h"
 
+#include <limits.h>
+
 #include "access/xact.h"
 #include "storage/shmem.h"
 #include "storage/spin.h"
@@ -205,6 +207,8 @@ static void hdefault(HTAB *hashp);
 static int     choose_nelem_alloc(Size entrysize);
 static bool init_htab(HTAB *hashp, long nelem);
 static void hash_corrupted(HTAB *hashp);
+static long next_pow2_long(long num);
+static int     next_pow2_int(long num);
 static void register_seq_scan(HTAB *hashp);
 static void deregister_seq_scan(HTAB *hashp);
 static bool has_seq_scans(HTAB *hashp);
@@ -379,8 +383,13 @@ hash_create(const char *tabname, long nelem, HASHCTL *info, int flags)
        {
                /* Doesn't make sense to partition a local hash table */
                Assert(flags & HASH_SHARED_MEM);
-               /* # of partitions had better be a power of 2 */
-               Assert(info->num_partitions == (1L << my_log2(info->num_partitions)));
+
+               /*
+                * The number of partitions had better be a power of 2. Also, it must
+                * be less than INT_MAX (see init_htab()), so call the int version of
+                * next_pow2.
+                */
+               Assert(info->num_partitions == next_pow2_int(info->num_partitions));
 
                hctl->num_partitions = info->num_partitions;
        }
@@ -523,7 +532,6 @@ init_htab(HTAB *hashp, long nelem)
 {
        HASHHDR    *hctl = hashp->hctl;
        HASHSEGMENT *segp;
-       long            lnbuckets;
        int                     nbuckets;
        int                     nsegs;
 
@@ -538,9 +546,7 @@ init_htab(HTAB *hashp, long nelem)
         * number of buckets.  Allocate space for the next greater power of two
         * number of buckets
         */
-       lnbuckets = (nelem - 1) / hctl->ffactor + 1;
-
-       nbuckets = 1 << my_log2(lnbuckets);
+       nbuckets = next_pow2_int((nelem - 1) / hctl->ffactor + 1);
 
        /*
         * In a partitioned table, nbuckets must be at least equal to
@@ -558,7 +564,7 @@ init_htab(HTAB *hashp, long nelem)
         * Figure number of directory segments needed, round up to a power of 2
         */
        nsegs = (nbuckets - 1) / hctl->ssize + 1;
-       nsegs = 1 << my_log2(nsegs);
+       nsegs = next_pow2_int(nsegs);
 
        /*
         * Make sure directory is big enough. If pre-allocated directory is too
@@ -628,9 +634,9 @@ hash_estimate_size(long num_entries, Size entrysize)
                                elementAllocCnt;
 
        /* estimate number of buckets wanted */
-       nBuckets = 1L << my_log2((num_entries - 1) / DEF_FFACTOR + 1);
+       nBuckets = next_pow2_long((num_entries - 1) / DEF_FFACTOR + 1);
        /* # of segments needed for nBuckets */
-       nSegments = 1L << my_log2((nBuckets - 1) / DEF_SEGSIZE + 1);
+       nSegments = next_pow2_long((nBuckets - 1) / DEF_SEGSIZE + 1);
        /* directory entries */
        nDirEntries = DEF_DIRSIZE;
        while (nDirEntries < nSegments)
@@ -671,9 +677,9 @@ hash_select_dirsize(long num_entries)
                                nDirEntries;
 
        /* estimate number of buckets wanted */
-       nBuckets = 1L << my_log2((num_entries - 1) / DEF_FFACTOR + 1);
+       nBuckets = next_pow2_long((num_entries - 1) / DEF_FFACTOR + 1);
        /* # of segments needed for nBuckets */
-       nSegments = 1L << my_log2((nBuckets - 1) / DEF_SEGSIZE + 1);
+       nSegments = next_pow2_long((nBuckets - 1) / DEF_SEGSIZE + 1);
        /* directory entries */
        nDirEntries = DEF_DIRSIZE;
        while (nDirEntries < nSegments)
@@ -1408,11 +1414,32 @@ my_log2(long num)
        int                     i;
        long            limit;
 
+       /* guard against too-large input, which would put us into infinite loop */
+       if (num > LONG_MAX / 2)
+               num = LONG_MAX / 2;
+
        for (i = 0, limit = 1; limit < num; i++, limit <<= 1)
                ;
        return i;
 }
 
+/* calculate first power of 2 >= num, bounded to what will fit in a long */
+static long
+next_pow2_long(long num)
+{
+       /* my_log2's internal range check is sufficient */
+       return 1L << my_log2(num);
+}
+
+/* calculate first power of 2 >= num, bounded to what will fit in an int */
+static int
+next_pow2_int(long num)
+{
+       if (num > INT_MAX / 2)
+               num = INT_MAX / 2;
+       return 1 << my_log2(num);
+}
+
 
 /************************* SEQ SCAN TRACKING ************************/