From 6ffa0000da851a50214105f44b034c50d29b6ee2 Mon Sep 17 00:00:00 2001
From: mariadb-AndreyPiskunov
Date: Thu, 1 Sep 2022 20:39:59 +0300
Subject: [PATCH] Online algorithm for regr_sxy

Replace the sum-based state (sumx, sumy, sumxy) with running means and a
running co-moment (avgx, avgy, cxy):

* nextValue() updates the means and the co-moment one row at a time.
* subEvaluate() merges two partial states. It skips an empty incoming
  partial, so merging two empty states never divides zero by zero.
* evaluate() returns cxy directly when at least one row was seen.
* dropValue() is commented out in the source and in the header, because
  subtracting a value from a running mean does not undo nextValue().
---
 utils/regr/regr_sxy.cpp | 77 +++++++++++++++++++++++++++++++++-------------
 utils/regr/regr_sxy.h   |  2 +-
 2 files changed, 56 insertions(+), 23 deletions(-)

diff --git a/utils/regr/regr_sxy.cpp b/utils/regr/regr_sxy.cpp
index 6101f71c1..5354306ac 100644
--- a/utils/regr/regr_sxy.cpp
+++ b/utils/regr/regr_sxy.cpp
@@ -39,9 +39,9 @@ static Add_regr_sxy_ToUDAFMap addToMap;
 struct regr_sxy_data
 {
   uint64_t cnt;
-  long double sumx;
-  long double sumy;
-  long double sumxy;  // sum of x * y
+  long double avgx;  // running mean of x
+  long double avgy;  // running mean of y
+  long double cxy;   // running co-moment: sum of (x - avgx) * (y - avgy)
 };
 
 mcsv1_UDAF::ReturnCode regr_sxy::init(mcsv1Context* context, ColumnDatum* colTypes)
@@ -74,9 +74,9 @@ mcsv1_UDAF::ReturnCode regr_sxy::reset(mcsv1Context* context)
 {
   struct regr_sxy_data* data = (struct regr_sxy_data*)context->getUserData()->data;
   data->cnt = 0;
-  data->sumx = 0.0;
-  data->sumy = 0.0;
-  data->sumxy = 0.0;
+  data->avgx = 0.0;
+  data->avgy = 0.0;
+  data->cxy = 0.0;
   return mcsv1_UDAF::SUCCESS;
 }
 
@@ -85,15 +85,20 @@ mcsv1_UDAF::ReturnCode regr_sxy::nextValue(mcsv1Context* context, ColumnDatum* v
   double valy = toDouble(valsIn[0]);
   double valx = toDouble(valsIn[1]);
   struct regr_sxy_data* data = (struct regr_sxy_data*)context->getUserData()->data;
-
-  data->sumy += valy;
-
-  data->sumx += valx;
-
-  data->sumxy += valx * valy;
-
+  long double avgyPrev = data->avgy;
+  long double avgxPrev = data->avgx;
+  long double cxyPrev = data->cxy;
   ++data->cnt;
-
+  uint64_t cnt = data->cnt;
+  // One-pass co-moment update: dx is taken against the mean of x from before
+  // this row, while the y deviation uses the mean of y from after this row.
+  long double dx = valx - avgxPrev;
+  avgyPrev += (valy - avgyPrev) / cnt;
+  avgxPrev += dx / cnt;
+  cxyPrev += dx * (valy - avgyPrev);
+  data->avgx = avgxPrev;
+  data->avgy = avgyPrev;
+  data->cxy = cxyPrev;
   return mcsv1_UDAF::SUCCESS;
 }
 
@@ -107,10 +112,34 @@ mcsv1_UDAF::ReturnCode regr_sxy::subEvaluate(mcsv1Context* context, const UserDa
   struct regr_sxy_data* outData = (struct regr_sxy_data*)context->getUserData()->data;
   struct regr_sxy_data* inData = (struct regr_sxy_data*)userDataIn->data;
 
-  outData->sumx += inData->sumx;
-  outData->sumy += inData->sumy;
-  outData->sumxy += inData->sumxy;
-  outData->cnt += inData->cnt;
+  uint64_t inCnt = inData->cnt;
+  if (inCnt == 0)
+  {
+    // An empty partial adds nothing. Returning here also keeps the division
+    // by resCnt below from evaluating 0 / 0 when both sides are empty.
+    return mcsv1_UDAF::SUCCESS;
+  }
+  long double inAvgx = inData->avgx;
+  long double inAvgy = inData->avgy;
+  long double inCxy = inData->cxy;
+
+  uint64_t outCnt = outData->cnt;
+  long double outAvgx = outData->avgx;
+  long double outAvgy = outData->avgy;
+  long double outCxy = outData->cxy;
+
+  // Merge the two partial states; resCnt is at least 1 at this point.
+  uint64_t resCnt = inCnt + outCnt;
+  long double deltax = outAvgx - inAvgx;
+  long double deltay = outAvgy - inAvgy;
+  long double resAvgx = inAvgx + deltax * outCnt / resCnt;
+  long double resAvgy = inAvgy + deltay * outCnt / resCnt;
+  long double resCxy = outCxy + inCxy + deltax * deltay * inCnt * outCnt / resCnt;
+
+  outData->avgx = resAvgx;
+  outData->avgy = resAvgy;
+  outData->cxy = resCxy;
+  outData->cnt = resCnt;
 
   return mcsv1_UDAF::SUCCESS;
 }
@@ -121,24 +150,28 @@ mcsv1_UDAF::ReturnCode regr_sxy::evaluate(mcsv1Context* context, static_any::any
   long double N = data->cnt;
   if (N > 0)
   {
-    long double regr_sxy = (data->sumxy - ((data->sumx * data->sumy) / N));
+    long double regr_sxy = data->cxy;
     valOut = static_cast<double>(regr_sxy);
   }
   return mcsv1_UDAF::SUCCESS;
 }
 
+// dropValue() is disabled: subtracting the dropped values from the running
+// means and co-moment, as the body below does, does not reverse nextValue().
+/*
 mcsv1_UDAF::ReturnCode regr_sxy::dropValue(mcsv1Context* context, ColumnDatum* valsDropped)
 {
   double valy = toDouble(valsDropped[0]);
   double valx = toDouble(valsDropped[1]);
   struct regr_sxy_data* data = (struct regr_sxy_data*)context->getUserData()->data;
 
-  data->sumy -= valy;
+  data->avgy -= valy;
 
-  data->sumx -= valx;
+  data->avgx -= valx;
 
-  data->sumxy -= valx * valy;
+  data->cxy -= valx * valy;
   --data->cnt;
 
   return mcsv1_UDAF::SUCCESS;
 }
+*/
diff --git a/utils/regr/regr_sxy.h b/utils/regr/regr_sxy.h
index bdfffc14a..6bcbf6500 100644
--- a/utils/regr/regr_sxy.h
+++ b/utils/regr/regr_sxy.h
@@ -70,7 +70,7 @@ class regr_sxy : public mcsv1_UDAF
 
   virtual ReturnCode evaluate(mcsv1Context* context, static_any::any& valOut);
 
-  virtual ReturnCode dropValue(mcsv1Context* context, ColumnDatum* valsDropped);
+  // virtual ReturnCode dropValue(mcsv1Context* context, ColumnDatum* valsDropped);
 
  protected:
 };