Optimise numeric multiplication using base-NBASE^2 arithmetic.
authorDean Rasheed <dean.a.rasheed@gmail.com>
Thu, 15 Aug 2024 09:36:17 +0000 (10:36 +0100)
committerDean Rasheed <dean.a.rasheed@gmail.com>
Thu, 15 Aug 2024 09:36:17 +0000 (10:36 +0100)
Currently mul_var() uses the schoolbook multiplication algorithm,
which is O(n^2) in the number of NBASE digits. To improve performance
for large inputs, convert the inputs to base NBASE^2 before
multiplying, which effectively halves the number of digits in each
input, theoretically speeding up the computation by a factor of 4. In
practice, the actual speedup for large inputs varies between around 3
and 6 times, depending on the system and compiler used. In turn, this
significantly reduces the runtime of the numeric_big regression test.

For this to work, 64-bit integers are required for the products of
base-NBASE^2 digits, so this works best on 64-bit machines, on which
it is faster whenever the shorter input has more than 4 or 5 NBASE
digits. On 32-bit machines, the additional overheads, especially
during carry propagation and the final conversion back to base-NBASE,
are significantly higher, and it is only faster when the shorter input
has more than around 50 NBASE digits. When the shorter input has more
than 6 NBASE digits (so that mul_var_short() cannot be used), but
fewer than around 50 NBASE digits, there may be a noticeable slowdown
on 32-bit machines. That seems to be an acceptable tradeoff, given the
performance gains for other inputs, and the effort that would be
required to maintain code specifically targeting 32-bit machines.

Joel Jacobson and Dean Rasheed.

Discussion: https://postgr.es/m/9d8a4a42-c354-41f3-bbf3-199e1957db97%40app.fastmail.com

src/backend/utils/adt/numeric.c

index 2a74312d354c67f7f6c58e188977326d4248bc99..77f64331f3659627d1d78b35bb734c20db640f8e 100644 (file)
@@ -101,6 +101,8 @@ typedef signed char NumericDigit;
 typedef int16 NumericDigit;
 #endif
 
+#define NBASE_SQR      (NBASE * NBASE)
+
 /*
  * The Numeric type as stored on disk.
  *
@@ -8668,21 +8670,30 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
                int rscale)
 {
        int                     res_ndigits;
+       int                     res_ndigitpairs;
        int                     res_sign;
        int                     res_weight;
+       int                     pair_offset;
        int                     maxdigits;
-       int                *dig;
-       int                     carry;
-       int                     maxdig;
-       int                     newdig;
+       int                     maxdigitpairs;
+       uint64     *dig,
+                          *dig_i1_off;
+       uint64          maxdig;
+       uint64          carry;
+       uint64          newdig;
        int                     var1ndigits;
        int                     var2ndigits;
+       int                     var1ndigitpairs;
+       int                     var2ndigitpairs;
        NumericDigit *var1digits;
        NumericDigit *var2digits;
+       uint32          var1digitpair;
+       uint32     *var2digitpairs;
        NumericDigit *res_digits;
        int                     i,
                                i1,
-                               i2;
+                               i2,
+                               i2limit;
 
        /*
         * Arrange for var1 to be the shorter of the two numbers.  This improves
@@ -8723,86 +8734,164 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
                return;
        }
 
-       /* Determine result sign and (maximum possible) weight */
+       /* Determine result sign */
        if (var1->sign == var2->sign)
                res_sign = NUMERIC_POS;
        else
                res_sign = NUMERIC_NEG;
-       res_weight = var1->weight + var2->weight + 2;
 
        /*
-        * Determine the number of result digits to compute.  If the exact result
-        * would have more than rscale fractional digits, truncate the computation
-        * with MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that
-        * would only contribute to the right of that.  (This will give the exact
+        * Determine the number of result digits to compute and the (maximum
+        * possible) result weight.  If the exact result would have more than
+        * rscale fractional digits, truncate the computation with
+        * MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that would
+        * only contribute to the right of that.  (This will give the exact
         * rounded-to-rscale answer unless carries out of the ignored positions
         * would have propagated through more than MUL_GUARD_DIGITS digits.)
         *
         * Note: an exact computation could not produce more than var1ndigits +
-        * var2ndigits digits, but we allocate one extra output digit in case
-        * rscale-driven rounding produces a carry out of the highest exact digit.
+        * var2ndigits digits, but we allocate at least one extra output digit in
+        * case rscale-driven rounding produces a carry out of the highest exact
+        * digit.
+        *
+        * The computation itself is done using base-NBASE^2 arithmetic, so we
+        * actually process the input digits in pairs, producing a base-NBASE^2
+        * intermediate result.  This significantly improves performance, since
+        * schoolbook multiplication is O(N^2) in the number of input digits, and
+        * working in base NBASE^2 effectively halves "N".
+        *
+        * Note: in a truncated computation, we must compute at least one extra
+        * output digit to ensure that all the guard digits are fully computed.
         */
-       res_ndigits = var1ndigits + var2ndigits + 1;
+       /* digit pairs in each input */
+       var1ndigitpairs = (var1ndigits + 1) / 2;
+       var2ndigitpairs = (var2ndigits + 1) / 2;
+
+       /* digits in exact result */
+       res_ndigits = var1ndigits + var2ndigits;
+
+       /* digit pairs in exact result with at least one extra output digit */
+       res_ndigitpairs = res_ndigits / 2 + 1;
+
+       /* pair offset to align result to end of dig[] */
+       pair_offset = res_ndigitpairs - var1ndigitpairs - var2ndigitpairs + 1;
+
+       /* maximum possible result weight (odd-length inputs shifted up below) */
+       res_weight = var1->weight + var2->weight + 1 + 2 * res_ndigitpairs -
+               res_ndigits - (var1ndigits & 1) - (var2ndigits & 1);
+
+       /* rscale-based truncation with at least one extra output digit */
        maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
                MUL_GUARD_DIGITS;
-       res_ndigits = Min(res_ndigits, maxdigits);
+       maxdigitpairs = maxdigits / 2 + 1;
+
+       res_ndigitpairs = Min(res_ndigitpairs, maxdigitpairs);
+       res_ndigits = 2 * res_ndigitpairs;
 
-       if (res_ndigits < 3)
+       /*
+        * In the computation below, digit pair i1 of var1 and digit pair i2 of
+        * var2 are multiplied and added to digit i1+i2+pair_offset of dig[]. Thus
+        * input digit pairs with index >= res_ndigitpairs - pair_offset don't
+        * contribute to the result, and can be ignored.
+        */
+       if (res_ndigitpairs <= pair_offset)
        {
                /* All input digits will be ignored; so result is zero */
                zero_var(result);
                result->dscale = rscale;
                return;
        }
+       var1ndigitpairs = Min(var1ndigitpairs, res_ndigitpairs - pair_offset);
+       var2ndigitpairs = Min(var2ndigitpairs, res_ndigitpairs - pair_offset);
 
        /*
-        * We do the arithmetic in an array "dig[]" of signed int's.  Since
-        * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
-        * to avoid normalizing carries immediately.
+        * We do the arithmetic in an array "dig[]" of unsigned 64-bit integers.
+        * Since PG_UINT64_MAX is much larger than NBASE^4, this gives us a lot of
+        * headroom to avoid normalizing carries immediately.
         *
         * maxdig tracks the maximum possible value of any dig[] entry; when this
-        * threatens to exceed INT_MAX, we take the time to propagate carries.
-        * Furthermore, we need to ensure that overflow doesn't occur during the
-        * carry propagation passes either.  The carry values could be as much as
-        * INT_MAX/NBASE, so really we must normalize when digits threaten to
-        * exceed INT_MAX - INT_MAX/NBASE.
+        * threatens to exceed PG_UINT64_MAX, we take the time to propagate
+        * carries.  Furthermore, we need to ensure that overflow doesn't occur
+        * during the carry propagation passes either.  The carry values could be
+        * as much as PG_UINT64_MAX / NBASE^2, so really we must normalize when
+        * digits threaten to exceed PG_UINT64_MAX - PG_UINT64_MAX / NBASE^2.
         *
-        * To avoid overflow in maxdig itself, it actually represents the max
-        * possible value divided by NBASE-1, ie, at the top of the loop it is
-        * known that no dig[] entry exceeds maxdig * (NBASE-1).
+        * To avoid overflow in maxdig itself, it actually represents the maximum
+        * possible value divided by NBASE^2-1, i.e., at the top of the loop it is
+        * known that no dig[] entry exceeds maxdig * (NBASE^2-1).
+        *
+        * The conversion of var1 to base NBASE^2 is done on the fly, as each new
+        * digit is required.  The digits of var2 are converted upfront, and
+        * stored at the end of dig[].  To avoid loss of precision, the input
+        * digits are aligned with the start of digit pair array, effectively
+        * shifting them up (multiplying by NBASE) if the inputs have an odd
+        * number of NBASE digits.
         */
-       dig = (int *) palloc0(res_ndigits * sizeof(int));
-       maxdig = 0;
+       dig = (uint64 *) palloc(res_ndigitpairs * sizeof(uint64) +
+                                                       var2ndigitpairs * sizeof(uint32));
+
+       /* convert var2 to base NBASE^2, shifting up if its length is odd */
+       var2digitpairs = (uint32 *) (dig + res_ndigitpairs);
+
+       for (i2 = 0; i2 < var2ndigitpairs - 1; i2++)
+               var2digitpairs[i2] = var2digits[2 * i2] * NBASE + var2digits[2 * i2 + 1];
+
+       if (2 * i2 + 1 < var2ndigits)
+               var2digitpairs[i2] = var2digits[2 * i2] * NBASE + var2digits[2 * i2 + 1];
+       else
+               var2digitpairs[i2] = var2digits[2 * i2] * NBASE;
 
        /*
-        * The least significant digits of var1 should be ignored if they don't
-        * contribute directly to the first res_ndigits digits of the result that
-        * we are computing.
+        * Start by multiplying var2 by the least significant contributing digit
+        * pair from var1, storing the results at the end of dig[], and filling
+        * the leading digits with zeros.
         *
-        * Digit i1 of var1 and digit i2 of var2 are multiplied and added to digit
-        * i1+i2+2 of the accumulator array, so we need only consider digits of
-        * var1 for which i1 <= res_ndigits - 3.
+        * The loop here is the same as the inner loop below, except that we set
+        * the results in dig[], rather than adding to them.  This is the
+        * performance bottleneck for multiplication, so we want to keep it simple
+        * enough so that it can be auto-vectorized.  Accordingly, process the
+        * digits left-to-right even though schoolbook multiplication would
+        * suggest right-to-left.  Since we aren't propagating carries in this
+        * loop, the order does not matter.
+        */
+       i1 = var1ndigitpairs - 1;
+       if (2 * i1 + 1 < var1ndigits)
+               var1digitpair = var1digits[2 * i1] * NBASE + var1digits[2 * i1 + 1];
+       else
+               var1digitpair = var1digits[2 * i1] * NBASE;
+       maxdig = var1digitpair;
+
+       i2limit = Min(var2ndigitpairs, res_ndigitpairs - i1 - pair_offset);
+       dig_i1_off = &dig[i1 + pair_offset];
+
+       memset(dig, 0, (i1 + pair_offset) * sizeof(uint64));
+       for (i2 = 0; i2 < i2limit; i2++)
+               dig_i1_off[i2] = (uint64) var1digitpair * var2digitpairs[i2];
+
+       /*
+        * Next, multiply var2 by the remaining digit pairs from var1, adding the
+        * results to dig[] at the appropriate offsets, and normalizing whenever
+        * there is a risk of any dig[] entry overflowing.
         */
-       for (i1 = Min(var1ndigits - 1, res_ndigits - 3); i1 >= 0; i1--)
+       for (i1 = i1 - 1; i1 >= 0; i1--)
        {
-               NumericDigit var1digit = var1digits[i1];
-
-               if (var1digit == 0)
+               var1digitpair = var1digits[2 * i1] * NBASE + var1digits[2 * i1 + 1];
+               if (var1digitpair == 0)
                        continue;
 
                /* Time to normalize? */
-               maxdig += var1digit;
-               if (maxdig > (INT_MAX - INT_MAX / NBASE) / (NBASE - 1))
+               maxdig += var1digitpair;
+               if (maxdig > (PG_UINT64_MAX - PG_UINT64_MAX / NBASE_SQR) / (NBASE_SQR - 1))
                {
-                       /* Yes, do it */
+                       /* Yes, do it (to base NBASE^2) */
                        carry = 0;
-                       for (i = res_ndigits - 1; i >= 0; i--)
+                       for (i = res_ndigitpairs - 1; i >= 0; i--)
                        {
                                newdig = dig[i] + carry;
-                               if (newdig >= NBASE)
+                               if (newdig >= NBASE_SQR)
                                {
-                                       carry = newdig / NBASE;
-                                       newdig -= carry * NBASE;
+                                       carry = newdig / NBASE_SQR;
+                                       newdig -= carry * NBASE_SQR;
                                }
                                else
                                        carry = 0;
@@ -8810,50 +8899,37 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
                        }
                        Assert(carry == 0);
                        /* Reset maxdig to indicate new worst-case */
-                       maxdig = 1 + var1digit;
+                       maxdig = 1 + var1digitpair;
                }
 
-               /*
-                * Add the appropriate multiple of var2 into the accumulator.
-                *
-                * As above, digits of var2 can be ignored if they don't contribute,
-                * so we only include digits for which i1+i2+2 < res_ndigits.
-                *
-                * This inner loop is the performance bottleneck for multiplication,
-                * so we want to keep it simple enough so that it can be
-                * auto-vectorized.  Accordingly, process the digits left-to-right
-                * even though schoolbook multiplication would suggest right-to-left.
-                * Since we aren't propagating carries in this loop, the order does
-                * not matter.
-                */
-               {
-                       int                     i2limit = Min(var2ndigits, res_ndigits - i1 - 2);
-                       int                *dig_i1_2 = &dig[i1 + 2];
+               /* Multiply and add */
+               i2limit = Min(var2ndigitpairs, res_ndigitpairs - i1 - pair_offset);
+               dig_i1_off = &dig[i1 + pair_offset];
 
-                       for (i2 = 0; i2 < i2limit; i2++)
-                               dig_i1_2[i2] += var1digit * var2digits[i2];
-               }
+               for (i2 = 0; i2 < i2limit; i2++)
+                       dig_i1_off[i2] += (uint64) var1digitpair * var2digitpairs[i2];
        }
 
        /*
-        * Now we do a final carry propagation pass to normalize the result, which
-        * we combine with storing the result digits into the output. Note that
-        * this is still done at full precision w/guard digits.
+        * Now we do a final carry propagation pass to normalize back to base
+        * NBASE^2, and construct the base-NBASE result digits.  Note that this is
+        * still done at full precision w/guard digits.
         */
        alloc_var(result, res_ndigits);
        res_digits = result->digits;
        carry = 0;
-       for (i = res_ndigits - 1; i >= 0; i--)
+       for (i = res_ndigitpairs - 1; i >= 0; i--)
        {
                newdig = dig[i] + carry;
-               if (newdig >= NBASE)
+               if (newdig >= NBASE_SQR)
                {
-                       carry = newdig / NBASE;
-                       newdig -= carry * NBASE;
+                       carry = newdig / NBASE_SQR;
+                       newdig -= carry * NBASE_SQR;
                }
                else
                        carry = 0;
-               res_digits[i] = newdig;
+               res_digits[2 * i + 1] = (NumericDigit) ((uint32) newdig % NBASE);
+               res_digits[2 * i] = (NumericDigit) ((uint32) newdig / NBASE);
        }
        Assert(carry == 0);