diff --git a/utils/rowgroup/rowaggregation.cpp b/utils/rowgroup/rowaggregation.cpp index e550a3f04..f42c89bbd 100644 --- a/utils/rowgroup/rowaggregation.cpp +++ b/utils/rowgroup/rowaggregation.cpp @@ -714,11 +714,8 @@ void RowAggregation::setJoinRowGroups(vector* pSmallSideRG, RowGroup* // threads on the PM and by multple threads on the UM. It must remain // thread safe. //------------------------------------------------------------------------------ -void RowAggregation::resetUDAF(uint64_t funcColID) +void RowAggregation::resetUDAF(RowUDAFFunctionCol* rowUDAF) { - // Get the UDAF class pointer and store in the row definition object. - RowUDAFFunctionCol* rowUDAF = dynamic_cast(fFunctionCols[funcColID].get()); - // RowAggregation and it's functions need to be re-entrant which means // each instance (thread) needs its own copy of the context object. // Note: operator=() doesn't copy userData. @@ -786,7 +783,7 @@ void RowAggregation::initialize() { if (fFunctionCols[i]->fAggFunction == ROWAGG_UDAF) { - resetUDAF(i); + resetUDAF(dynamic_cast(fFunctionCols[i].get())); } } } @@ -838,7 +835,7 @@ void RowAggregation::aggReset() { if (fFunctionCols[i]->fAggFunction == ROWAGG_UDAF) { - resetUDAF(i); + resetUDAF(dynamic_cast(fFunctionCols[i].get())); } } } @@ -885,14 +882,28 @@ void RowAggregationUM::aggregateRowWithRemap(Row& row) inserted.first->second = RowPosition(fResultDataVec.size() - 1, fRowGroupOut->getRowCount() - 1); // If there's UDAF involved, reset the user data. - for (uint64_t i = 0; i < fFunctionCols.size(); i++) + if (fOrigFunctionCols) { - if (fFunctionCols[i]->fAggFunction == ROWAGG_UDAF) + // This is a multi-distinct query and fFunctionCols may not + // contain all the UDAF we need to reset + for (uint64_t i = 0; i < fOrigFunctionCols->size(); i++) { - resetUDAF(i); + if ((*fOrigFunctionCols)[i]->fAggFunction == ROWAGG_UDAF) + { + resetUDAF(dynamic_cast((*fOrigFunctionCols)[i].get())); + } + } + } + else + { + for (uint64_t i = 0; i < fFunctionCols.size(); i++) + { + if (fFunctionCols[i]->fAggFunction == ROWAGG_UDAF) + { + resetUDAF(dynamic_cast(fFunctionCols[i].get())); + } } } - // replace the key value with an equivalent copy, yes this is OK const_cast((inserted.first->first)) = pos; } @@ -946,14 +957,28 @@ void RowAggregation::aggregateRow(Row& row) RowPosition(fResultDataVec.size() - 1, fRowGroupOut->getRowCount() - 1); // If there's UDAF involved, reset the user data. - for (uint64_t i = 0; i < fFunctionCols.size(); i++) + if (fOrigFunctionCols) { - if (fFunctionCols[i]->fAggFunction == ROWAGG_UDAF) + // This is a multi-distinct query and fFunctionCols may not + // contain all the UDAF we need to reset + for (uint64_t i = 0; i < fOrigFunctionCols->size(); i++) { - resetUDAF(i); + if ((*fOrigFunctionCols)[i]->fAggFunction == ROWAGG_UDAF) + { + resetUDAF(dynamic_cast((*fOrigFunctionCols)[i].get())); + } + } + } + else + { + for (uint64_t i = 0; i < fFunctionCols.size(); i++) + { + if (fFunctionCols[i]->fAggFunction == ROWAGG_UDAF) + { + resetUDAF(dynamic_cast(fFunctionCols[i].get())); + } } } - } else { @@ -4699,7 +4724,7 @@ void RowAggregationMultiDistinct::doDistinctAggregation() { // backup the function column vector for finalize(). vector origFunctionCols = fFunctionCols; - + fOrigFunctionCols = &origFunctionCols; // aggregate data from each sub-aggregator to distinct aggregator for (uint64_t i = 0; i < fSubAggregators.size(); ++i) { @@ -4727,6 +4752,7 @@ void RowAggregationMultiDistinct::doDistinctAggregation() // restore the function column vector fFunctionCols = origFunctionCols; + fOrigFunctionCols = NULL; } @@ -4734,7 +4760,8 @@ void RowAggregationMultiDistinct::doDistinctAggregation_rowVec(vector origFunctionCols = fFunctionCols; - + fOrigFunctionCols = &origFunctionCols; + // aggregate data from each sub-aggregator to distinct aggregator for (uint64_t i = 0; i < fSubAggregators.size(); ++i) { @@ -4751,9 +4778,9 @@ void RowAggregationMultiDistinct::doDistinctAggregation_rowVec(vectorclear(); } - void resetUDAF(uint64_t funcColID); + void resetUDAF(RowUDAFFunctionCol* rowUDAF); inline bool isNull(const RowGroup* pRowGroup, const Row& row, int64_t col); inline void makeAggFieldsNull(Row& row); @@ -710,6 +710,9 @@ protected: static const static_any::any& doubleTypeId; static const static_any::any& longdoubleTypeId; static const static_any::any& strTypeId; + + // For UDAF along with with multiple distinct columns + vector* fOrigFunctionCols = NULL; }; //------------------------------------------------------------------------------