Don't reference out-of-bounds array elements in brin_minmax_multi.c
authorDavid Rowley <drowley@postgresql.org>
Mon, 12 Sep 2022 23:02:56 +0000 (11:02 +1200)
committerDavid Rowley <drowley@postgresql.org>
Mon, 12 Sep 2022 23:02:56 +0000 (11:02 +1200)
The primary fix here is to fix has_matching_range() so it does not
reference ranges->values[-1] when nranges == 0.  Similar problems existed
in AssertCheckRanges() too.  It does not look like any of these problems
could lead to a crash as the array in question is at the end of the Ranges
struct, and values[-1] is memory that belongs to other fields in the
struct.  However, let's get rid of these rather unsafe coding practices.

In passing, I (David) adjusted some comments to try to make it more clear
what some of the fields are for in the Ranges struct.  I had to study the
code to find out what nsorted was for as I couldn't tell from the
comments.

Author: Ranier Vilela
Discussion: https://postgr.es/m/CAEudQAqJQzPitufX-jR=YUbJafpCDAKUnwgdbX_MzSc93wuvdw@mail.gmail.com
Backpatch-through: 14, where multi-range brin was added.

src/backend/access/brin/brin_minmax_multi.c

index 4415669719f1ad332e1f6ad8f30fbc3b958b0b90..dbad7764316d09141a63ba7b5b20759f53ba1cfb 100644 (file)
@@ -142,19 +142,23 @@ typedef struct MinMaxMultiOptions
  * The Ranges struct stores the boundary values in a single array, but we
  * treat regular and single-point ranges differently to save space. For
  * regular ranges (with different boundary values) we have to store both
- * values, while for "single-point ranges" we only need to save one value.
+ * the lower and upper bound of the range, while for "single-point ranges"
+ * we only need to store a single value.
  *
  * The 'values' array stores boundary values for regular ranges first (there
  * are 2*nranges values to store), and then the nvalues boundary values for
  * single-point ranges. That is, we have (2*nranges + nvalues) boundary
  * values in the array.
  *
- * +---------------------------------+-------------------------------+
- * | ranges (sorted pairs of values) | sorted values (single points) |
- * +---------------------------------+-------------------------------+
+ * +-------------------------+----------------------------------+
+ * | ranges (2 * nranges of) | single point values (nvalues of) |
+ * +-------------------------+----------------------------------+
  *
  * This allows us to quickly add new values, and store outliers without
- * making the other ranges very wide.
+ * having to widen any of the existing range values.
+ *
+ * 'nsorted' denotes how many of 'nvalues' in the values[] array are sorted.
+ * When nsorted == nvalues, all single point values are sorted.
  *
  * We never store more than maxvalues values (as set by values_per_range
  * reloption). If needed we merge some of the ranges.
@@ -173,10 +177,10 @@ typedef struct Ranges
        FmgrInfo   *cmp;
 
        /* (2*nranges + nvalues) <= maxvalues */
-       int                     nranges;                /* number of ranges in the array (stored) */
-       int                     nsorted;                /* number of sorted values (ranges + points) */
-       int                     nvalues;                /* number of values in the data array (all) */
-       int                     maxvalues;              /* maximum number of values (reloption) */
+       int                     nranges;                /* number of ranges in the values[] array */
+       int                     nsorted;                /* number of nvalues which are sorted */
+       int                     nvalues;                /* number of point values in values[] array */
+       int                     maxvalues;              /* number of elements in the values[] array */
 
        /*
         * We simply add the values into a large buffer, without any expensive
@@ -318,102 +322,99 @@ AssertCheckRanges(Ranges *ranges, FmgrInfo *cmpFn, Oid colloid)
         * Check that none of the values are not covered by ranges (both sorted
         * and unsorted)
         */
-       for (i = 0; i < ranges->nvalues; i++)
+       if (ranges->nranges > 0)
        {
-               Datum           compar;
-               int                     start,
-                                       end;
-               Datum           minvalue,
-                                       maxvalue;
-
-               Datum           value = ranges->values[2 * ranges->nranges + i];
-
-               if (ranges->nranges == 0)
-                       break;
-
-               minvalue = ranges->values[0];
-               maxvalue = ranges->values[2 * ranges->nranges - 1];
-
-               /*
-                * Is the value smaller than the minval? If yes, we'll recurse to the
-                * left side of range array.
-                */
-               compar = FunctionCall2Coll(cmpFn, colloid, value, minvalue);
-
-               /* smaller than the smallest value in the first range */
-               if (DatumGetBool(compar))
-                       continue;
-
-               /*
-                * Is the value greater than the maxval? If yes, we'll recurse to the
-                * right side of range array.
-                */
-               compar = FunctionCall2Coll(cmpFn, colloid, maxvalue, value);
-
-               /* larger than the largest value in the last range */
-               if (DatumGetBool(compar))
-                       continue;
-
-               start = 0;                              /* first range */
-               end = ranges->nranges - 1;      /* last range */
-               while (true)
+               for (i = 0; i < ranges->nvalues; i++)
                {
-                       int                     midpoint = (start + end) / 2;
-
-                       /* this means we ran out of ranges in the last step */
-                       if (start > end)
-                               break;
+                       Datum           compar;
+                       int                     start,
+                                               end;
+                       Datum           minvalue = ranges->values[0];
+                       Datum           maxvalue = ranges->values[2 * ranges->nranges - 1];
+                       Datum           value = ranges->values[2 * ranges->nranges + i];
 
-                       /* copy the min/max values from the ranges */
-                       minvalue = ranges->values[2 * midpoint];
-                       maxvalue = ranges->values[2 * midpoint + 1];
+                       compar = FunctionCall2Coll(cmpFn, colloid, value, minvalue);
 
                        /*
-                        * Is the value smaller than the minval? If yes, we'll recurse to
-                        * the left side of range array.
+                        * If the value is smaller than the lower bound in the first range
+                        * then it cannot possibly be in any of the ranges.
                         */
-                       compar = FunctionCall2Coll(cmpFn, colloid, value, minvalue);
-
-                       /* smaller than the smallest value in this range */
                        if (DatumGetBool(compar))
-                       {
-                               end = (midpoint - 1);
                                continue;
-                       }
 
-                       /*
-                        * Is the value greater than the minval? If yes, we'll recurse to
-                        * the right side of range array.
-                        */
                        compar = FunctionCall2Coll(cmpFn, colloid, maxvalue, value);
 
-                       /* larger than the largest value in this range */
+                       /*
+                        * Likewise, if the value is larger than the upper bound of the
+                        * final range, then it cannot possibly be inside any of the
+                        * ranges.
+                        */
                        if (DatumGetBool(compar))
-                       {
-                               start = (midpoint + 1);
                                continue;
-                       }
 
-                       /* hey, we found a matching range */
-                       Assert(false);
+                       /* bsearch the ranges to see if 'value' fits within any of them */
+                       start = 0;                      /* first range */
+                       end = ranges->nranges - 1;      /* last range */
+                       while (true)
+                       {
+                               int                     midpoint = (start + end) / 2;
+
+                               /* this means we ran out of ranges in the last step */
+                               if (start > end)
+                                       break;
+
+                               /* copy the min/max values from the ranges */
+                               minvalue = ranges->values[2 * midpoint];
+                               maxvalue = ranges->values[2 * midpoint + 1];
+
+                               /*
+                                * Is the value smaller than the minval? If yes, we'll recurse
+                                * to the left side of range array.
+                                */
+                               compar = FunctionCall2Coll(cmpFn, colloid, value, minvalue);
+
+                               /* smaller than the smallest value in this range */
+                               if (DatumGetBool(compar))
+                               {
+                                       end = (midpoint - 1);
+                                       continue;
+                               }
+
+                               /*
+                                * Is the value greater than the minval? If yes, we'll recurse
+                                * to the right side of range array.
+                                */
+                               compar = FunctionCall2Coll(cmpFn, colloid, maxvalue, value);
+
+                               /* larger than the largest value in this range */
+                               if (DatumGetBool(compar))
+                               {
+                                       start = (midpoint + 1);
+                                       continue;
+                               }
+
+                               /* hey, we found a matching range */
+                               Assert(false);
+                       }
                }
        }
 
-       /* and values in the unsorted part must not be in sorted part */
-       for (i = ranges->nsorted; i < ranges->nvalues; i++)
+       /* and values in the unsorted part must not be in the sorted part */
+       if (ranges->nsorted > 0)
        {
                compare_context cxt;
-               Datum           value = ranges->values[2 * ranges->nranges + i];
-
-               if (ranges->nsorted == 0)
-                       break;
 
                cxt.colloid = ranges->colloid;
                cxt.cmpFn = ranges->cmp;
 
-               Assert(bsearch_arg(&value, &ranges->values[2 * ranges->nranges],
-                                                  ranges->nsorted, sizeof(Datum),
-                                                  compare_values, (void *) &cxt) == NULL);
+               for (i = ranges->nsorted; i < ranges->nvalues; i++)
+               {
+                       Datum           value = ranges->values[2 * ranges->nranges + i];
+
+                       Assert(bsearch_arg(&value, &ranges->values[2 * ranges->nranges],
+                                                          ranges->nsorted, sizeof(Datum),
+                                                          compare_values, (void *) &cxt) == NULL);
+               }
        }
 #endif
 }
@@ -923,8 +924,8 @@ has_matching_range(BrinDesc *bdesc, Oid colloid, Ranges *ranges,
 {
        Datum           compar;
 
-       Datum           minvalue = ranges->values[0];
-       Datum           maxvalue = ranges->values[2 * ranges->nranges - 1];
+       Datum           minvalue;
+       Datum           maxvalue;
 
        FmgrInfo   *cmpLessFn;
        FmgrInfo   *cmpGreaterFn;
@@ -936,6 +937,9 @@ has_matching_range(BrinDesc *bdesc, Oid colloid, Ranges *ranges,
        if (ranges->nranges == 0)
                return false;
 
+       minvalue = ranges->values[0];
+       maxvalue = ranges->values[2 * ranges->nranges - 1];
+
        /*
         * Otherwise, need to compare the new value with boundaries of all the
         * ranges. First check if it's less than the absolute minimum, which is