From ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c Mon Sep 17 00:00:00 2001 From: Dean Rasheed Date: Tue, 9 Jul 2024 10:00:42 +0100 Subject: [PATCH] Optimise numeric multiplication for short inputs. When either input has a small number of digits, and the exact product is requested, the speed of numeric multiplication can be increased significantly by using a faster direct multiplication algorithm. This works by fully computing each result digit in turn, starting with the least significant, and propagating the carry up. This save cycles by not requiring a temporary buffer to store digit products, not making multiple passes over the digits of the longer input, and not requiring separate carry-propagation passes. For now, this is used when the shorter input has 1-4 NBASE digits (up to 13-16 decimal digits), and the longer input is of any size, which covers a lot of common real-world cases. Also, the relative benefit increases as the size of the longer input increases. Possible future work would be to try extending the technique to larger numbers of digits in the shorter input. Joel Jacobson and Dean Rasheed. Discussion: https://postgr.es/m/44d2ffca-d560-4919-b85a-4d07060946aa@app.fastmail.com --- src/backend/utils/adt/numeric.c | 220 +++++++++++++++++++++++++++++++- 1 file changed, 219 insertions(+), 1 deletion(-) diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c index 57386aabdfe..f6e20cf704c 100644 --- a/src/backend/utils/adt/numeric.c +++ b/src/backend/utils/adt/numeric.c @@ -558,6 +558,8 @@ static void sub_var(const NumericVar *var1, const NumericVar *var2, static void mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result, int rscale); +static void mul_var_short(const NumericVar *var1, const NumericVar *var2, + NumericVar *result); static void div_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result, int rscale, bool round); @@ -8722,7 +8724,7 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result, var1digits = var1->digits; var2digits = var2->digits; - if (var1ndigits == 0 || var2ndigits == 0) + if (var1ndigits == 0) { /* one or both inputs is zero; so is result */ zero_var(result); @@ -8730,6 +8732,16 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result, return; } + /* + * If var1 has 1-4 digits and the exact result was requested, delegate to + * mul_var_short() which uses a faster direct multiplication algorithm. + */ + if (var1ndigits <= 4 && rscale == var1->dscale + var2->dscale) + { + mul_var_short(var1, var2, result); + return; + } + /* Determine result sign and (maximum possible) weight */ if (var1->sign == var2->sign) res_sign = NUMERIC_POS; @@ -8880,6 +8892,212 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result, } +/* + * mul_var_short() - + * + * Special-case multiplication function used when var1 has 1-4 digits, var2 + * has at least as many digits as var1, and the exact product var1 * var2 is + * requested. + */ +static void +mul_var_short(const NumericVar *var1, const NumericVar *var2, + NumericVar *result) +{ + int var1ndigits = var1->ndigits; + int var2ndigits = var2->ndigits; + NumericDigit *var1digits = var1->digits; + NumericDigit *var2digits = var2->digits; + int res_sign; + int res_weight; + int res_ndigits; + NumericDigit *res_buf; + NumericDigit *res_digits; + uint32 carry; + uint32 term; + + /* Check preconditions */ + Assert(var1ndigits >= 1); + Assert(var1ndigits <= 4); + Assert(var2ndigits >= var1ndigits); + + /* + * Determine the result sign, weight, and number of digits to calculate. + * The weight figured here is correct if the product has no leading zero + * digits; otherwise strip_var() will fix things up. Note that, unlike + * mul_var(), we do not need to allocate an extra output digit, because we + * are not rounding here. + */ + if (var1->sign == var2->sign) + res_sign = NUMERIC_POS; + else + res_sign = NUMERIC_NEG; + res_weight = var1->weight + var2->weight + 1; + res_ndigits = var1ndigits + var2ndigits; + + /* Allocate result digit array */ + res_buf = digitbuf_alloc(res_ndigits + 1); + res_buf[0] = 0; /* spare digit for later rounding */ + res_digits = res_buf + 1; + + /* + * Compute the result digits in reverse, in one pass, propagating the + * carry up as we go. The i'th result digit consists of the sum of the + * products var1digits[i1] * var2digits[i2] for which i = i1 + i2 + 1. + */ + switch (var1ndigits) + { + case 1: + /* --------- + * 1-digit case: + * var1ndigits = 1 + * var2ndigits >= 1 + * res_ndigits = var2ndigits + 1 + * ---------- + */ + carry = 0; + for (int i = res_ndigits - 2; i >= 0; i--) + { + term = (uint32) var1digits[0] * var2digits[i] + carry; + res_digits[i + 1] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + } + res_digits[0] = (NumericDigit) carry; + break; + + case 2: + /* --------- + * 2-digit case: + * var1ndigits = 2 + * var2ndigits >= 2 + * res_ndigits = var2ndigits + 2 + * ---------- + */ + /* last result digit and carry */ + term = (uint32) var1digits[1] * var2digits[res_ndigits - 3]; + res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + /* remaining digits, except for the first two */ + for (int i = res_ndigits - 3; i >= 1; i--) + { + term = (uint32) var1digits[0] * var2digits[i] + + (uint32) var1digits[1] * var2digits[i - 1] + carry; + res_digits[i + 1] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + } + + /* first two digits */ + term = (uint32) var1digits[0] * var2digits[0] + carry; + res_digits[1] = (NumericDigit) (term % NBASE); + res_digits[0] = (NumericDigit) (term / NBASE); + break; + + case 3: + /* --------- + * 3-digit case: + * var1ndigits = 3 + * var2ndigits >= 3 + * res_ndigits = var2ndigits + 3 + * ---------- + */ + /* last two result digits */ + term = (uint32) var1digits[2] * var2digits[res_ndigits - 4]; + res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + term = (uint32) var1digits[1] * var2digits[res_ndigits - 4] + + (uint32) var1digits[2] * var2digits[res_ndigits - 5] + carry; + res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + /* remaining digits, except for the first three */ + for (int i = res_ndigits - 4; i >= 2; i--) + { + term = (uint32) var1digits[0] * var2digits[i] + + (uint32) var1digits[1] * var2digits[i - 1] + + (uint32) var1digits[2] * var2digits[i - 2] + carry; + res_digits[i + 1] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + } + + /* first three digits */ + term = (uint32) var1digits[0] * var2digits[1] + + (uint32) var1digits[1] * var2digits[0] + carry; + res_digits[2] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + term = (uint32) var1digits[0] * var2digits[0] + carry; + res_digits[1] = (NumericDigit) (term % NBASE); + res_digits[0] = (NumericDigit) (term / NBASE); + break; + + case 4: + /* --------- + * 4-digit case: + * var1ndigits = 4 + * var2ndigits >= 4 + * res_ndigits = var2ndigits + 4 + * ---------- + */ + /* last three result digits */ + term = (uint32) var1digits[3] * var2digits[res_ndigits - 5]; + res_digits[res_ndigits - 1] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + term = (uint32) var1digits[2] * var2digits[res_ndigits - 5] + + (uint32) var1digits[3] * var2digits[res_ndigits - 6] + carry; + res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + term = (uint32) var1digits[1] * var2digits[res_ndigits - 5] + + (uint32) var1digits[2] * var2digits[res_ndigits - 6] + + (uint32) var1digits[3] * var2digits[res_ndigits - 7] + carry; + res_digits[res_ndigits - 3] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + /* remaining digits, except for the first four */ + for (int i = res_ndigits - 5; i >= 3; i--) + { + term = (uint32) var1digits[0] * var2digits[i] + + (uint32) var1digits[1] * var2digits[i - 1] + + (uint32) var1digits[2] * var2digits[i - 2] + + (uint32) var1digits[3] * var2digits[i - 3] + carry; + res_digits[i + 1] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + } + + /* first four digits */ + term = (uint32) var1digits[0] * var2digits[2] + + (uint32) var1digits[1] * var2digits[1] + + (uint32) var1digits[2] * var2digits[0] + carry; + res_digits[3] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + term = (uint32) var1digits[0] * var2digits[1] + + (uint32) var1digits[1] * var2digits[0] + carry; + res_digits[2] = (NumericDigit) (term % NBASE); + carry = term / NBASE; + + term = (uint32) var1digits[0] * var2digits[0] + carry; + res_digits[1] = (NumericDigit) (term % NBASE); + res_digits[0] = (NumericDigit) (term / NBASE); + break; + } + + /* Store the product in result */ + digitbuf_free(result->buf); + result->ndigits = res_ndigits; + result->buf = res_buf; + result->digits = res_digits; + result->weight = res_weight; + result->sign = res_sign; + result->dscale = var1->dscale + var2->dscale; + + /* Strip leading and trailing zeroes */ + strip_var(result); +} + + /* * div_var() - *