1
0
mirror of https://github.com/postgres/postgres.git synced 2025-07-03 20:02:46 +03:00

Add width_bucket(anyelement, anyarray).

This provides a convenient method of classifying input values into buckets
that are not necessarily equal-width.  It works on any sortable data type.

The choice of function name is a bit debatable, perhaps, but showing that
there's a relationship to the SQL standard's width_bucket() function seems
more attractive than the other proposals.

Petr Jelinek, reviewed by Pavel Stehule
This commit is contained in:
Tom Lane
2014-09-09 15:34:10 -04:00
parent 220bb39dee
commit e80252d424
7 changed files with 458 additions and 12 deletions

View File

@ -15,8 +15,10 @@
#include "postgres.h"
#include <ctype.h>
#include <math.h>
#include "access/htup_details.h"
#include "catalog/pg_type.h"
#include "funcapi.h"
#include "libpq/pqformat.h"
#include "utils/array.h"
@ -130,6 +132,15 @@ static ArrayType *array_replace_internal(ArrayType *array,
Datum replace, bool replace_isnull,
bool remove, Oid collation,
FunctionCallInfo fcinfo);
static int width_bucket_array_float8(Datum operand, ArrayType *thresholds);
static int width_bucket_array_fixed(Datum operand,
ArrayType *thresholds,
Oid collation,
TypeCacheEntry *typentry);
static int width_bucket_array_variable(Datum operand,
ArrayType *thresholds,
Oid collation,
TypeCacheEntry *typentry);
/*
@ -5502,3 +5513,235 @@ array_replace(PG_FUNCTION_ARGS)
fcinfo);
PG_RETURN_ARRAYTYPE_P(array);
}
/*
* Implements width_bucket(anyelement, anyarray).
*
* 'thresholds' is an array containing lower bound values for each bucket;
* these must be sorted from smallest to largest, or bogus results will be
* produced. If N thresholds are supplied, the output is from 0 to N:
* 0 is for inputs < first threshold, N is for inputs >= last threshold.
*/
Datum
width_bucket_array(PG_FUNCTION_ARGS)
{
Datum operand = PG_GETARG_DATUM(0);
ArrayType *thresholds = PG_GETARG_ARRAYTYPE_P(1);
Oid collation = PG_GET_COLLATION();
Oid element_type = ARR_ELEMTYPE(thresholds);
int result;
/* Check input */
if (ARR_NDIM(thresholds) > 1)
ereport(ERROR,
(errcode(ERRCODE_ARRAY_SUBSCRIPT_ERROR),
errmsg("thresholds must be one-dimensional array")));
if (array_contains_nulls(thresholds))
ereport(ERROR,
(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
errmsg("thresholds array must not contain NULLs")));
/* We have a dedicated implementation for float8 data */
if (element_type == FLOAT8OID)
result = width_bucket_array_float8(operand, thresholds);
else
{
TypeCacheEntry *typentry;
/* Cache information about the input type */
typentry = (TypeCacheEntry *) fcinfo->flinfo->fn_extra;
if (typentry == NULL ||
typentry->type_id != element_type)
{
typentry = lookup_type_cache(element_type,
TYPECACHE_CMP_PROC_FINFO);
if (!OidIsValid(typentry->cmp_proc_finfo.fn_oid))
ereport(ERROR,
(errcode(ERRCODE_UNDEFINED_FUNCTION),
errmsg("could not identify a comparison function for type %s",
format_type_be(element_type))));
fcinfo->flinfo->fn_extra = (void *) typentry;
}
/*
* We have separate implementation paths for fixed- and variable-width
* types, since indexing the array is a lot cheaper in the first case.
*/
if (typentry->typlen > 0)
result = width_bucket_array_fixed(operand, thresholds,
collation, typentry);
else
result = width_bucket_array_variable(operand, thresholds,
collation, typentry);
}
/* Avoid leaking memory when handed toasted input. */
PG_FREE_IF_COPY(thresholds, 1);
PG_RETURN_INT32(result);
}
/*
* width_bucket_array for float8 data.
*/
static int
width_bucket_array_float8(Datum operand, ArrayType *thresholds)
{
float8 op = DatumGetFloat8(operand);
float8 *thresholds_data;
int left;
int right;
/*
* Since we know the array contains no NULLs, we can just index it
* directly.
*/
thresholds_data = (float8 *) ARR_DATA_PTR(thresholds);
left = 0;
right = ArrayGetNItems(ARR_NDIM(thresholds), ARR_DIMS(thresholds));
/*
* If the probe value is a NaN, it's greater than or equal to all possible
* threshold values (including other NaNs), so we need not search. Note
* that this would give the same result as searching even if the array
* contains multiple NaNs (as long as they're correctly sorted), since the
* loop logic will find the rightmost of multiple equal threshold values.
*/
if (isnan(op))
return right;
/* Find the bucket */
while (left < right)
{
int mid = (left + right) / 2;
if (isnan(thresholds_data[mid]) || op < thresholds_data[mid])
right = mid;
else
left = mid + 1;
}
return left;
}
/*
* width_bucket_array for generic fixed-width data types.
*/
static int
width_bucket_array_fixed(Datum operand,
ArrayType *thresholds,
Oid collation,
TypeCacheEntry *typentry)
{
char *thresholds_data;
int typlen = typentry->typlen;
bool typbyval = typentry->typbyval;
FunctionCallInfoData locfcinfo;
int left;
int right;
/*
* Since we know the array contains no NULLs, we can just index it
* directly.
*/
thresholds_data = (char *) ARR_DATA_PTR(thresholds);
InitFunctionCallInfoData(locfcinfo, &typentry->cmp_proc_finfo, 2,
collation, NULL, NULL);
/* Find the bucket */
left = 0;
right = ArrayGetNItems(ARR_NDIM(thresholds), ARR_DIMS(thresholds));
while (left < right)
{
int mid = (left + right) / 2;
char *ptr;
int32 cmpresult;
ptr = thresholds_data + mid * typlen;
locfcinfo.arg[0] = operand;
locfcinfo.arg[1] = fetch_att(ptr, typbyval, typlen);
locfcinfo.argnull[0] = false;
locfcinfo.argnull[1] = false;
locfcinfo.isnull = false;
cmpresult = DatumGetInt32(FunctionCallInvoke(&locfcinfo));
if (cmpresult < 0)
right = mid;
else
left = mid + 1;
}
return left;
}
/*
* width_bucket_array for generic variable-width data types.
*/
static int
width_bucket_array_variable(Datum operand,
ArrayType *thresholds,
Oid collation,
TypeCacheEntry *typentry)
{
char *thresholds_data;
int typlen = typentry->typlen;
bool typbyval = typentry->typbyval;
char typalign = typentry->typalign;
FunctionCallInfoData locfcinfo;
int left;
int right;
thresholds_data = (char *) ARR_DATA_PTR(thresholds);
InitFunctionCallInfoData(locfcinfo, &typentry->cmp_proc_finfo, 2,
collation, NULL, NULL);
/* Find the bucket */
left = 0;
right = ArrayGetNItems(ARR_NDIM(thresholds), ARR_DIMS(thresholds));
while (left < right)
{
int mid = (left + right) / 2;
char *ptr;
int i;
int32 cmpresult;
/* Locate mid'th array element by advancing from left element */
ptr = thresholds_data;
for (i = left; i < mid; i++)
{
ptr = att_addlength_pointer(ptr, typlen, ptr);
ptr = (char *) att_align_nominal(ptr, typalign);
}
locfcinfo.arg[0] = operand;
locfcinfo.arg[1] = fetch_att(ptr, typbyval, typlen);
locfcinfo.argnull[0] = false;
locfcinfo.argnull[1] = false;
locfcinfo.isnull = false;
cmpresult = DatumGetInt32(FunctionCallInvoke(&locfcinfo));
if (cmpresult < 0)
right = mid;
else
{
left = mid + 1;
/*
* Move the thresholds pointer to match new "left" index, so we
* don't have to seek over those elements again. This trick
* ensures we do only O(N) array indexing work, not O(N^2).
*/
ptr = att_addlength_pointer(ptr, typlen, ptr);
thresholds_data = (char *) att_align_nominal(ptr, typalign);
}
}
return left;
}