1
0
mirror of https://github.com/mariadb-corporation/mariadb-columnstore-engine.git synced 2025-07-29 08:21:15 +03:00

Welford algorithm for STD and VAR

The naive algorithm for calculating STD and VAR is subject to catastrophic
cancellation. The well-known Welford's algorithm is used instead.
This commit is contained in:
Andrey Piskunov
2022-06-01 19:02:24 +03:00
parent 4e50fca460
commit c5fa27475d
8 changed files with 1003 additions and 40 deletions

View File

@ -1900,8 +1900,8 @@ void RowAggregation::doAvg(const Row& rowIn, int64_t colIn, int64_t colOut, int6
// rowIn(in) - Row to be included in aggregation.
// colIn(in) - column in the input row group
// colOut(in) - column in the output row group stores the count
// colAux(in) - column in the output row group stores the sum(x)
// colAux + 1 - column in the output row group stores the sum(x**2)
// colAux(in) - column in the output row group stores the mean(x)
// colAux + 1 - column in the output row group stores the sum((x_i - mean)^2)
//------------------------------------------------------------------------------
void RowAggregation::doStatistics(const Row& rowIn, int64_t colIn, int64_t colOut, int64_t colAux)
{
@ -1960,9 +1960,17 @@ void RowAggregation::doStatistics(const Row& rowIn, int64_t colIn, int64_t colOu
break;
}
fRow.setDoubleField(fRow.getDoubleField(colOut) + 1.0, colOut);
fRow.setLongDoubleField(fRow.getLongDoubleField(colAux) + valIn, colAux);
fRow.setLongDoubleField(fRow.getLongDoubleField(colAux + 1) + valIn * valIn, colAux + 1);
double count = fRow.getDoubleField(colOut) + 1.0;
long double mean = fRow.getLongDoubleField(colAux);
long double M2 = fRow.getLongDoubleField(colAux + 1);
volatile long double delta = valIn - mean;
mean += delta/count;
M2 += delta * (valIn - mean);
fRow.setDoubleField(count, colOut);
fRow.setLongDoubleField(mean, colAux);
fRow.setLongDoubleField(M2, colAux + 1);
}
void RowAggregation::mergeStatistics(const Row& rowIn, uint64_t colOut, uint64_t colAux)
@ -3156,31 +3164,26 @@ void RowAggregationUM::calculateStatisticsFunctions()
}
else // count > 1
{
long double sum1 = fRow.getLongDoubleField(colAux);
long double sum2 = fRow.getLongDoubleField(colAux + 1);
long double M2 = fRow.getLongDoubleField(colAux + 1);
uint32_t scale = fRow.getScale(colOut);
auto factor = datatypes::scaleDivisor<long double>(scale);
if (scale != 0) // adjust the scale if necessary
{
sum1 /= factor;
sum2 /= factor * factor;
M2 /= factor * factor;
}
long double stat = sum1 * sum1 / cnt;
stat = sum2 - stat;
if (fFunctionCols[i]->fStatsFunction == ROWAGG_STDDEV_POP)
stat = sqrt(stat / cnt);
M2 = sqrt(M2 / cnt);
else if (fFunctionCols[i]->fStatsFunction == ROWAGG_STDDEV_SAMP)
stat = sqrt(stat / (cnt - 1));
M2 = sqrt(M2 / (cnt - 1));
else if (fFunctionCols[i]->fStatsFunction == ROWAGG_VAR_POP)
stat = stat / cnt;
M2 = M2 / cnt;
else if (fFunctionCols[i]->fStatsFunction == ROWAGG_VAR_SAMP)
stat = stat / (cnt - 1);
M2 = M2 / (cnt - 1);
fRow.setDoubleField(stat, colOut);
fRow.setDoubleField(M2, colOut);
}
}
}
@ -4281,18 +4284,39 @@ void RowAggregationUMP2::doAvg(const Row& rowIn, int64_t colIn, int64_t colOut,
// Update the sum and count fields for statistics if input is not null.
// rowIn(in) - Row to be included in aggregation.
// colIn(in) - column in the input row group stores the count/logical block
// colIn + 1 - column in the input row group stores the sum(x)/logical block
// colIn + 2 - column in the input row group stores the sum(x**2)/logical block
// colIn + 1 - column in the input row group stores the mean(x)/logical block
// colIn + 2 - column in the input row group stores the sum((x_i - mean)^2)/logical block
// colOut(in) - column in the output row group stores the count
// colAux(in) - column in the output row group stores the sum(x)
// colAux + 1 - column in the output row group stores the sum(x**2)
// colAux(in) - column in the output row group stores the mean(x)
// colAux + 1 - column in the output row group stores the sum((x_i - mean)^2)
//------------------------------------------------------------------------------
void RowAggregationUMP2::doStatistics(const Row& rowIn, int64_t colIn, int64_t colOut, int64_t colAux)
{
  // Merges the partial (count, mean, M2) triple carried by rowIn into the running
  // triple stored in fRow, where M2 is the sum of squared deviations from the mean.
  //   rowIn : colIn = count, colIn + 1 = mean, colIn + 2 = M2
  //   fRow  : colOut = count, colAux = mean, colAux + 1 = M2
  // The two triples are combined with the pairwise update
  //   mean = (mean_a * n_a + mean_b * n_b) / (n_a + n_b)
  //   M2   = M2_a + M2_b + (mean_a - mean_b)^2 * n_a * n_b / (n_a + n_b)
  //
  // The running fields must be read before anything is written back: adding the
  // incoming count/mean/M2 straight into fRow first would make the merge below
  // account for the incoming block twice and would sum means as if they were sums.
  const double count = fRow.getDoubleField(colOut);
  const long double mean = fRow.getLongDoubleField(colAux);
  const long double M2 = fRow.getLongDoubleField(colAux + 1);

  const double block_count = rowIn.getDoubleField(colIn);
  const long double block_mean = rowIn.getLongDoubleField(colIn + 1);
  const long double block_M2 = rowIn.getLongDoubleField(colIn + 2);

  const double next_count = count + block_count;
  long double next_mean = 0;
  long double next_M2 = 0;

  // With no values on either side there is nothing to merge; keep the zeroes and
  // avoid dividing by a zero count.
  if (next_count != 0)
  {
    // volatile is kept from the original code; presumably it prevents the compiler
    // from folding or re-associating the subtraction -- TODO confirm.
    volatile long double delta = mean - block_mean;
    next_mean = (mean * count + block_mean * block_count) / next_count;
    next_M2 = M2 + block_M2 + delta * delta * (count * block_count / next_count);
  }

  fRow.setDoubleField(next_count, colOut);
  fRow.setLongDoubleField(next_mean, colAux);
  fRow.setLongDoubleField(next_M2, colAux + 1);
}
//------------------------------------------------------------------------------