1
0
mirror of https://github.com/sqlite/sqlite.git synced 2025-07-30 19:03:16 +03:00

Add support making use of sqlite3_aggregate_context() (in a roundabout way) from Java to accumulate state within aggregate and window UDFs.

FossilOrigin-Name: 640574984741c7a9472d7f8be7bce87e736d7947ce673ae4a25008d74238ad90
This commit is contained in:
stephan
2023-07-28 01:12:47 +00:00
parent 8ba5d79c35
commit 48a8352a39
6 changed files with 214 additions and 42 deletions

View File

@ -271,10 +271,11 @@ enum {
typedef struct NphCacheLine NphCacheLine;
struct NphCacheLine {
const char * zClassName /* "full/class/Name" */;
jclass klazz /* global ref to concrete NPH class */;
jmethodID midSet /* setNativePointer() */;
jmethodID midGet /* getNativePointer() */;
jmethodID midCtor /* constructor */;
jclass klazz /* global ref to concrete NPH class */;
jmethodID midSet /* setNativePointer() */;
jmethodID midGet /* getNativePointer() */;
jmethodID midCtor /* constructor */;
jmethodID midSetAgg /* sqlite3_context::setAggregateContext() */;
};
typedef struct JNIEnvCacheLine JNIEnvCacheLine;
@ -713,6 +714,42 @@ static void * getNativePointer(JNIEnv * env, jobject pObj, const char *zClassNam
}
}
/**
Requires that jCx be a Java-side sqlite3_context wrapper for pCx.
This function calls sqlite3_aggregate_context() to allocate a tiny
sliver of memory, the address of which is set in
jCx->setAggregateContext(). The memory is only used as a key for
mapping, client-side, results of aggregate result sets across
xStep() and xFinal() methods.
isFinal must be 1 for xFinal() calls and 0 for all others.
*/
static void setAggregateContext(JNIEnv * env, jobject jCx,
sqlite3_context * pCx,
int isFinal){
jmethodID setter;
void * pAgg;
struct NphCacheLine * const cacheLine =
S3Global_nph_cache(env, ClassNames.sqlite3_context);
if(cacheLine && cacheLine->klazz && cacheLine->midSetAgg){
setter = cacheLine->midSetAgg;
assert(setter);
}else{
jclass const klazz =
cacheLine ? cacheLine->klazz : (*env)->GetObjectClass(env, jCx);
setter = (*env)->GetMethodID(env, klazz, "setAggregateContext", "(J)V");
if(cacheLine){
assert(cacheLine->klazz);
assert(!cacheLine->midSetAgg);
cacheLine->midSetAgg = setter;
}
}
pAgg = sqlite3_aggregate_context(pCx, isFinal ? 0 : 8);
(*env)->CallVoidMethod(env, jCx, setter, (jlong)pAgg);
IFTHREW_REPORT;
}
/*
** This function is NOT part of the sqlite3 public API. It is strictly
** for use by the sqlite project's own Java/JNI bindings.
@ -1054,6 +1091,11 @@ typedef struct {
jobjectArray jargv;
} udf_jargs;
/**
Converts the given (cx, argc, argv) into arguments for the given
UDF, placing the result in the final argument. Returns 0 on
success, SQLITE_NOMEM on allocation error.
*/
static int udf_args(sqlite3_context * const cx,
int argc, sqlite3_value**argv,
UDFState * const s,
@ -1102,19 +1144,23 @@ static int udf_report_exception(sqlite3_context * cx, UDFState *s,
return rc;
}
static int udf_xFSI(sqlite3_context* cx, int argc,
static int udf_xFSI(sqlite3_context* pCx, int argc,
sqlite3_value** argv,
UDFState * s,
jmethodID xMethodID,
const char * zFuncType){
udf_jargs args;
JNIEnv * const env = s->env;
int rc = udf_args(cx, argc, argv, s, &args);
int rc = udf_args(pCx, argc, argv, s, &args);
//MARKER(("%s.%s() pCx = %p\n", s->zFuncName, zFuncType, pCx));
if(rc) return rc;
//MARKER(("UDF::%s.%s()\n", s->zFuncName, zFuncType));
if( UDF_SCALAR != s->type ){
setAggregateContext(env, args.jcx, pCx, 0);
}
(*env)->CallVoidMethod(env, s->jObj, xMethodID, args.jcx, args.jargv);
IFTHREW{
rc = udf_report_exception(cx,s, zFuncType);
rc = udf_report_exception(pCx,s, zFuncType);
}
UNREF_L(args.jcx);
UNREF_L(args.jargv);
@ -1127,11 +1173,15 @@ static int udf_xFV(sqlite3_context* cx, UDFState * s,
JNIEnv * const env = s->env;
jobject jcx = new_sqlite3_context_wrapper(s->env, cx);
int rc = 0;
//MARKER(("%s.%s() cx = %p\n", s->zFuncName, zFuncType, cx));
if(!jcx){
sqlite3_result_error_nomem(cx);
return SQLITE_NOMEM;
}
//MARKER(("UDF::%s.%s()\n", s->zFuncName, zFuncType));
if( UDF_SCALAR != s->type ){
setAggregateContext(env, jcx, cx, 1);
}
(*env)->CallVoidMethod(env, s->jObj, xMethodID, jcx);
IFTHREW{
rc = udf_report_exception(cx,s, zFuncType);

View File

@ -19,9 +19,62 @@ package org.sqlite.jni;
access to the callback functions needed in order to implement SQL
functions in Java. This class is not used by itself: see the
three inner classes.
Note that if a given function is called multiple times in a single
SQL statement, e.g. SELECT MYFUNC(A), MYFUNC(B)..., then the
context object passed to each one will be different. This is most
significant for aggregates and window functions, since they must
assign their results to the proper context.
TODO: add helper APIs to map sqlite3_context instances to
func-specific state and to clear that when the aggregate or window
function is done.
*/
public abstract class SQLFunction {
/**
ContextMap is a helper for use with aggregate and window
functions, to help them manage their accumulator state across
calls to xStep() and xFinal(). It works by mapping
sqlite3_context::getAggregateContext() to a single piece of state
which persists across a set of 0 or more SQLFunction.xStep()
calls and 1 SQLFunction.xFinal() call.
*/
public static final class ContextMap<T> {
private java.util.Map<Long,ValueHolder<T>> map
= new java.util.HashMap<Long,ValueHolder<T>>();
/**
Should be called from a UDF's xStep() method, passing it that
method's first argument and an initial value for the persistent
state. If there is currently no mapping for
cx.getAggregateContext() within the map, one is created, else
an existing one is preferred. It returns a ValueHolder which
can be used to modify that state directly without having to put
a new result back in the underlying map.
*/
public ValueHolder<T> xStep(sqlite3_context cx, T initialValue){
ValueHolder<T> rc = map.get(cx.getAggregateContext());
if(null == rc){
map.put(cx.getAggregateContext(), rc = new ValueHolder<T>(initialValue));
}
return rc;
}
/**
Should be called from a UDF's xFinal() method and passed that
method's first argument. This function returns the value
associated with cx.getAggregateContext(), or null if
this.xStep() has not been called to set up such a mapping. That
will be the case if an aggregate is used in a statement which
has no result rows.
*/
public T xFinal(sqlite3_context cx){
final ValueHolder<T> h = map.remove(cx.getAggregateContext());
return null==h ? null : h.value;
}
}
//! Subclass for creating scalar functions.
public static abstract class Scalar extends SQLFunction {
public abstract void xFunc(sqlite3_context cx, sqlite3_value[] args);
@ -33,18 +86,36 @@ public abstract class SQLFunction {
}
//! Subclass for creating aggregate functions.
public static abstract class Aggregate extends SQLFunction {
public static abstract class Aggregate<T> extends SQLFunction {
public abstract void xStep(sqlite3_context cx, sqlite3_value[] args);
public abstract void xFinal(sqlite3_context cx);
public void xDestroy() {}
private final ContextMap<T> map = new ContextMap<>();
/**
See ContextMap<T>.xStep().
*/
public final ValueHolder<T> getAggregateState(sqlite3_context cx, T initialValue){
return map.xStep(cx, initialValue);
}
/**
See ContextMap<T>.xFinal().
*/
public final T takeAggregateState(sqlite3_context cx){
return map.xFinal(cx);
}
}
//! Subclass for creating window functions.
public static abstract class Window extends SQLFunction {
public abstract void xStep(sqlite3_context cx, sqlite3_value[] args);
public static abstract class Window<T> extends Aggregate<T> {
public Window(){
super();
}
//public abstract void xStep(sqlite3_context cx, sqlite3_value[] args);
public abstract void xInverse(sqlite3_context cx, sqlite3_value[] args);
public abstract void xFinal(sqlite3_context cx);
//public abstract void xFinal(sqlite3_context cx);
public abstract void xValue(sqlite3_context cx);
public void xDestroy() {}
}
}

View File

@ -482,21 +482,23 @@ public class Tester1 {
private static void testUdfAggregate(){
final sqlite3 db = createNewDb();
SQLFunction func = new SQLFunction.Aggregate(){
private int accum = 0;
@Override public void xStep(sqlite3_context cx, sqlite3_value args[]){
this.accum += sqlite3_value_int(args[0]);
SQLFunction func = new SQLFunction.Aggregate<Integer>(){
@Override
public void xStep(sqlite3_context cx, sqlite3_value args[]){
this.getAggregateState(cx, 0).value += sqlite3_value_int(args[0]);
}
@Override public void xFinal(sqlite3_context cx){
sqlite3_result_int(cx, this.accum);
this.accum = 0;
@Override
public void xFinal(sqlite3_context cx){
final Integer v = this.takeAggregateState(cx);
if(null == v) sqlite3_result_null(cx);
else sqlite3_result_int(cx, v);
}
};
execSql(db, "CREATE TABLE t(a); INSERT INTO t(a) VALUES(1),(2),(3)");
int rc = sqlite3_create_function(db, "myfunc", 1, SQLITE_UTF8, func);
affirm(0 == rc);
sqlite3_stmt stmt = new sqlite3_stmt();
sqlite3_prepare(db, "select myfunc(a) from t", stmt);
sqlite3_prepare(db, "select myfunc(a), myfunc(a+10) from t", stmt);
affirm( 0 != stmt.getNativePointer() );
int n = 0;
if( SQLITE_ROW == sqlite3_step(stmt) ){
@ -514,6 +516,20 @@ public class Tester1 {
}
sqlite3_finalize(stmt);
affirm( 1==n );
rc = sqlite3_prepare(db, "select myfunc(a), myfunc(a+a) from t order by a",
stmt);
affirm( 0 == rc );
n = 0;
while( SQLITE_ROW == sqlite3_step(stmt) ){
final int c0 = sqlite3_column_int(stmt, 0);
final int c1 = sqlite3_column_int(stmt, 1);
++n;
affirm( 6 == c0 );
affirm( 12 == c1 );
}
affirm( 1 == n );
sqlite3_finalize(stmt);
sqlite3_close(db);
}
@ -521,26 +537,27 @@ public class Tester1 {
final sqlite3 db = createNewDb();
/* Example window function, table, and results taken from:
https://sqlite.org/windowfunctions.html#udfwinfunc */
final SQLFunction func = new SQLFunction.Window(){
private int accum = 0;
private void xStepInverse(int v){
this.accum += v;
}
private void xFinalValue(sqlite3_context cx){
sqlite3_result_int(cx, this.accum);
final SQLFunction func = new SQLFunction.Window<Integer>(){
private void xStepInverse(sqlite3_context cx, int v){
this.getAggregateState(cx,0).value += v;
}
@Override public void xStep(sqlite3_context cx, sqlite3_value[] args){
this.xStepInverse(sqlite3_value_int(args[0]));
this.xStepInverse(cx, sqlite3_value_int(args[0]));
}
@Override public void xInverse(sqlite3_context cx, sqlite3_value[] args){
this.xStepInverse(-sqlite3_value_int(args[0]));
this.xStepInverse(cx, -sqlite3_value_int(args[0]));
}
private void xFinalValue(sqlite3_context cx, Integer v){
if(null == v) sqlite3_result_null(cx);
else sqlite3_result_int(cx, v);
}
@Override public void xFinal(sqlite3_context cx){
this.xFinalValue(cx);
this.accum = 0;
xFinalValue(cx, this.takeAggregateState(cx));
}
@Override public void xValue(sqlite3_context cx){
this.xFinalValue(cx);
xFinalValue(cx, this.getAggregateState(cx,null).value);
}
};
int rc = sqlite3_create_function(db, "winsumint", 1, SQLITE_UTF8, func);

View File

@ -13,8 +13,42 @@
*/
package org.sqlite.jni;
/**
sqlite3_context instances are used in conjunction with user-defined
SQL functions (a.k.a. UDFs). They are opaque pointers.
The getAggregateContext() method corresponds to C's
sqlite3_aggregate_context(), with a slightly different interface in
order to account for cross-language differences. It serves the same
purposes in a slightly different way: it provides a key which is
stable across invocations of UDF xStep() and xFinal() pairs, to
which a UDF may map state across such calls (e.g. a numeric result
which is being accumulated).
*/
public class sqlite3_context extends NativePointerHolder<sqlite3_context> {
public sqlite3_context() {
super();
}
private long aggcx = 0;
/**
If this object is being used in the context of an aggregate or
window UDF, the UDF binding layer will set a unique context value
here. That value will be the same across matching calls to the
xStep() and xFinal() routines, as well as xValue() and xInverse()
in window UDFs. This value can be used as a key to map state
which needs to persist across such calls, noting that such state
should be cleaned up via xFinal().
*/
public long getAggregateContext(){
return aggcx;
}
/**
For use only by the JNI layer. It's permitted to call this even
though it's private.
*/
private void setAggregateContext(long n){
aggcx = n;
}
}