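/*
 * tsstat (lexeme statistics) functions: tsstat I/O, the ts_accum accumulator,
 * and the ts_accum_finish / ts_stat set-returning functions that expand the
 * collected statistics into rows.
 */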

#include "tsvector.h"
#include "ts_stat.h"
#include "funcapi.h"
#include "catalog/pg_type.h"
#include "executor/spi.h"
#include "common.h"
#include "ts_locale.h"

PG_FUNCTION_INFO_V1(tsstat_in);
Datum		tsstat_in(PG_FUNCTION_ARGS);
Datum
tsstat_in(PG_FUNCTION_ARGS)
{
	tsstat	   *stat = palloc(STATHDRSIZE);

	stat->len = STATHDRSIZE;
	stat->size = 0;
	stat->weight = 0;
	PG_RETURN_POINTER(stat);
}

PG_FUNCTION_INFO_V1(tsstat_out);
Datum		tsstat_out(PG_FUNCTION_ARGS);

/*
 * tsstat_out - output function for the tsstat type.
 *
 * tsstat has no textual representation, so any attempt to print one raises
 * ERRCODE_FEATURE_NOT_SUPPORTED.  ts_accum_finish() expands a tsstat into
 * rows instead.
 */
Datum
tsstat_out(PG_FUNCTION_ARGS)
{
	ereport(ERROR,
			(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
			 errmsg("tsstat_out not implemented")));

	/* not reached; keeps compilers that can't see ereport(ERROR) never returns quiet */
	PG_RETURN_NULL();
}

/*
 * check_weight - count a lexeme's positions whose weight is selected by a mask.
 *
 * txt:    the tsvector that owns wptr
 * wptr:   lexeme entry whose position list is examined
 * weight: bitmask; bit (1 << w) selects weight class w
 *
 * Returns how many of the lexeme's POSDATALEN positions have their weight
 * bit set in `weight`.
 */
static int
check_weight(tsvector * txt, WordEntry * wptr, int8 weight)
{
	WordEntryPos *pos = POSDATAPTR(txt, wptr);
	int			npos = POSDATALEN(txt, wptr);
	int			matched = 0;
	int			i;

	for (i = 0; i < npos; i++)
	{
		if (weight & (1 << WEP_GETWEIGHT(pos[i])))
			matched++;
	}

	return matched;
}

/*
 * SEI_realloc - grow the WordEntry* array used to collect new lexemes.
 *
 * in:  current array, or NULL
 * len: in/out capacity of the array, in elements
 *
 * When there is already an array with non-zero capacity, the capacity is
 * doubled and the array reallocated; otherwise a fresh array of 8 slots is
 * allocated.  Returns the (possibly moved) array; *len holds the new capacity.
 */
static WordEntry **
SEI_realloc(WordEntry ** in, uint32 *len)
{
	if (*len != 0 && in != NULL)
	{
		*len *= 2;
		return (WordEntry **) repalloc(in, sizeof(WordEntry *) * (*len));
	}

	*len = 8;
	return (WordEntry **) palloc(sizeof(WordEntry *) * (*len));
}

/*
 * compareStatWord - order a tsstat entry against a tsvector lexeme.
 *
 * a:    entry whose text lives in stat's string area
 * b:    lexeme whose text lives in txt's string area
 *
 * Shorter words sort first (-1 / 1).  Words of equal length are compared
 * with strncmp over that length, and its result is returned.  formstat()
 * and ts_accum() rely on this ordering when merging and binary searching.
 */
static int
compareStatWord(StatEntry * a, WordEntry * b, tsstat * stat, tsvector * txt)
{
	if (a->len != b->len)
		return (a->len < b->len) ? -1 : 1;

	return strncmp(STATSTRPTR(stat) + a->pos, STRPTR(txt) + b->pos, a->len);
}

static tsstat *
formstat(tsstat * stat, tsvector * txt, WordEntry ** entry, uint32 len)
{
	tsstat	   *newstat;
	uint32		totallen,
				nentry;
	uint32		slen = 0;
	WordEntry **ptr = entry;
	char	   *curptr;
	StatEntry  *sptr,
			   *nptr;

	while (ptr - entry < len)
	{
		slen += (*ptr)->len;
		ptr++;
	}

	nentry = stat->size + len;
	slen += STATSTRSIZE(stat);
	totallen = CALCSTATSIZE(nentry, slen);
	newstat = palloc(totallen);
	newstat->len = totallen;
	newstat->weight = stat->weight;
	newstat->size = nentry;

	memcpy(STATSTRPTR(newstat), STATSTRPTR(stat), STATSTRSIZE(stat));
	curptr = STATSTRPTR(newstat) + STATSTRSIZE(stat);

	ptr = entry;
	sptr = STATPTR(stat);
	nptr = STATPTR(newstat);

	if (len == 1)
	{
		StatEntry  *StopLow = STATPTR(stat);
		StatEntry  *StopHigh = (StatEntry *) STATSTRPTR(stat);

		while (StopLow < StopHigh)
		{
			sptr = StopLow + (StopHigh - StopLow) / 2;
			if (compareStatWord(sptr, *ptr, stat, txt) < 0)
				StopLow = sptr + 1;
			else
				StopHigh = sptr;
		}
		nptr = STATPTR(newstat) + (StopLow - STATPTR(stat));
		memcpy(STATPTR(newstat), STATPTR(stat), sizeof(StatEntry) * (StopLow - STATPTR(stat)));
		if ((*ptr)->haspos)
			nptr->nentry = (stat->weight) ? check_weight(txt, *ptr, stat->weight) : POSDATALEN(txt, *ptr);
		else
			nptr->nentry = 1;
		nptr->ndoc = 1;
		nptr->len = (*ptr)->len;
		memcpy(curptr, STRPTR(txt) + (*ptr)->pos, nptr->len);
		nptr->pos = curptr - STATSTRPTR(newstat);
		memcpy(nptr + 1, StopLow, sizeof(StatEntry) * (((StatEntry *) STATSTRPTR(stat)) - StopLow));
	}
	else
	{
		while (sptr - STATPTR(stat) < stat->size && ptr - entry < len)
		{
			if (compareStatWord(sptr, *ptr, stat, txt) < 0)
			{
				memcpy(nptr, sptr, sizeof(StatEntry));
				sptr++;
			}
			else
			{
				if ((*ptr)->haspos)
					nptr->nentry = (stat->weight) ? check_weight(txt, *ptr, stat->weight) : POSDATALEN(txt, *ptr);
				else
					nptr->nentry = 1;
				nptr->ndoc = 1;
				nptr->len = (*ptr)->len;
				memcpy(curptr, STRPTR(txt) + (*ptr)->pos, nptr->len);
				nptr->pos = curptr - STATSTRPTR(newstat);
				curptr += nptr->len;
				ptr++;
			}
			nptr++;
		}

		memcpy(nptr, sptr, sizeof(StatEntry) * (stat->size - (sptr - STATPTR(stat))));

		while (ptr - entry < len)
		{
			if ((*ptr)->haspos)
				nptr->nentry = (stat->weight) ? check_weight(txt, *ptr, stat->weight) : POSDATALEN(txt, *ptr);
			else
				nptr->nentry = 1;
			nptr->ndoc = 1;
			nptr->len = (*ptr)->len;
			memcpy(curptr, STRPTR(txt) + (*ptr)->pos, nptr->len);
			nptr->pos = curptr - STATSTRPTR(newstat);
			curptr += nptr->len;
			ptr++;
			nptr++;
		}
	}

	return newstat;
}

PG_FUNCTION_INFO_V1(ts_accum);
Datum		ts_accum(PG_FUNCTION_ARGS);

/*
 * ts_accum - fold one tsvector into a running tsstat.
 *
 * arg 0: the running tsstat.  When it is NULL (pointer or SQL NULL), an
 *        empty one is created.
 * arg 1: the tsvector to account for.  NULL or empty contributes nothing.
 *
 * Lexemes already present in the state are updated in place: ndoc is
 * incremented and their position count is added to nentry.  Lexemes not yet
 * present are collected and appended via formstat(), which returns a new,
 * larger tsstat.  The old state is not freed; callers that need that do it
 * themselves (ts_stat_sql does).  When stat->weight is non-zero, only
 * positions whose weight is in the mask are counted.  A lexeme with no
 * matching positions is then ignored.
 */
Datum
ts_accum(PG_FUNCTION_ARGS)
{
	tsstat	   *newstat,
			   *stat = (tsstat *) PG_GETARG_POINTER(0);
	tsvector   *txt;
	WordEntry **newentry = NULL;	/* lexemes not yet present in stat */
	uint32		len = 0,		/* allocated capacity of newentry */
				cur = 0;		/* used slots in newentry */
	StatEntry  *sptr;
	WordEntry  *wptr;
	int			n = 0;

	if (stat == NULL || PG_ARGISNULL(0))
	{							/* Init in first */
		stat = palloc(STATHDRSIZE);
		stat->len = STATHDRSIZE;
		stat->size = 0;
		stat->weight = 0;
	}

	/*
	 * A NULL tsvector contributes nothing.  Test for it before detoasting:
	 * PG_DETOAST_DATUM must not be applied to a null argument.
	 */
	if (PG_ARGISNULL(1))
		PG_RETURN_POINTER(stat);

	txt = (tsvector *) PG_DETOAST_DATUM(PG_GETARG_DATUM(1));

	/* an empty tsvector contributes nothing either */
	if (txt->size == 0)
	{
		PG_FREE_IF_COPY(txt, 1);
		PG_RETURN_POINTER(stat);
	}

	sptr = STATPTR(stat);
	wptr = ARRPTR(txt);

	/*
	 * Both paths assume stat's entries and the tsvector's lexemes are sorted
	 * in compareStatWord order.  A linear merge costs O(stat->size +
	 * txt->size); a per-lexeme binary search costs O(txt->size * log
	 * stat->size).  Merge unless stat is much (100x) bigger than the tsvector.
	 */
	if (stat->size < 100 * txt->size)
	{							/* merge */
		while (sptr - STATPTR(stat) < stat->size && wptr - ARRPTR(txt) < txt->size)
		{
			int			cmp = compareStatWord(sptr, wptr, stat, txt);

			if (cmp < 0)
				sptr++;
			else if (cmp == 0)
			{
				/* lexeme already known: update its counters in place */
				if (stat->weight == 0)
				{
					sptr->ndoc++;
					sptr->nentry += (wptr->haspos) ? POSDATALEN(txt, wptr) : 1;
				}
				else if (wptr->haspos && (n = check_weight(txt, wptr, stat->weight)) != 0)
				{
					sptr->ndoc++;
					sptr->nentry += n;
				}
				sptr++;
				wptr++;
			}
			else
			{
				/* lexeme sorts before the current stat entry: it is new */
				if (stat->weight == 0 || check_weight(txt, wptr, stat->weight) != 0)
				{
					if (cur == len)
						newentry = SEI_realloc(newentry, &len);
					newentry[cur] = wptr;
					cur++;
				}
				wptr++;
			}
		}

		/* stat exhausted: every remaining lexeme is new */
		while (wptr - ARRPTR(txt) < txt->size)
		{
			if (stat->weight == 0 || check_weight(txt, wptr, stat->weight) != 0)
			{
				if (cur == len)
					newentry = SEI_realloc(newentry, &len);
				newentry[cur] = wptr;
				cur++;
			}
			wptr++;
		}
	}
	else
	{							/* search */
		while (wptr - ARRPTR(txt) < txt->size)
		{
			StatEntry  *StopLow = STATPTR(stat);
			StatEntry  *StopHigh = (StatEntry *) STATSTRPTR(stat);
			int			cmp;

			while (StopLow < StopHigh)
			{
				sptr = StopLow + (StopHigh - StopLow) / 2;
				cmp = compareStatWord(sptr, wptr, stat, txt);
				if (cmp == 0)
				{
					if (stat->weight == 0)
					{
						sptr->ndoc++;
						sptr->nentry += (wptr->haspos) ? POSDATALEN(txt, wptr) : 1;
					}
					else if (wptr->haspos && (n = check_weight(txt, wptr, stat->weight)) != 0)
					{
						sptr->ndoc++;
						sptr->nentry += n;
					}
					break;
				}
				else if (cmp < 0)
					StopLow = sptr + 1;
				else
					StopHigh = sptr;
			}

			/* the search only ends with StopLow < StopHigh when it hit a match */
			if (StopLow >= StopHigh)
			{					/* not found */
				if (stat->weight == 0 || check_weight(txt, wptr, stat->weight) != 0)
				{
					if (cur == len)
						newentry = SEI_realloc(newentry, &len);
					newentry[cur] = wptr;
					cur++;
				}
			}
			wptr++;
		}
	}

	if (cur == 0)
	{							/* no new words */
		PG_FREE_IF_COPY(txt, 1);
		PG_RETURN_POINTER(stat);
	}

	newstat = formstat(stat, txt, newentry, cur);
	pfree(newentry);
	PG_FREE_IF_COPY(txt, 1);
	/* the old state is intentionally not freed here; see header comment */

	PG_RETURN_POINTER(newstat);
}

typedef struct
{
	uint32		cur;
	tsvector   *stat;
}	StatStorage;

/*
 * ts_setup_firstcall - first-call setup shared by the tsstat-expanding SRFs.
 *
 * Inside funcctx->multi_call_memory_ctx, this does three things:
 *   - copies `stat` into a new StatStorage positioned at entry 0, so the
 *     caller may release its own stat afterwards;
 *   - installs that StatStorage as user_fctx;
 *   - builds attinmeta from a copy of the function's declared result row
 *     type.
 * Raises an error when the declared result type is not composite.
 */
static void
ts_setup_firstcall(FunctionCallInfo fcinfo, FuncCallContext *funcctx,
				   tsstat * stat)
{
	MemoryContext saved_ctx;
	StatStorage *storage;
	TupleDesc	resultdesc;

	/* everything built here must outlive this call */
	saved_ctx = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);

	storage = (StatStorage *) palloc(sizeof(StatStorage));
	storage->cur = 0;
	storage->stat = palloc(stat->len);
	memcpy(storage->stat, stat, stat->len);
	funcctx->user_fctx = (void *) storage;

	if (get_call_result_type(fcinfo, NULL, &resultdesc) != TYPEFUNC_COMPOSITE)
		elog(ERROR, "return type must be a row type");
	funcctx->attinmeta = TupleDescGetAttInMetadata(CreateTupleDescCopy(resultdesc));

	MemoryContextSwitchTo(saved_ctx);
}


/*
 * ts_process_call - produce the next (word, ndoc, nentry) row.
 *
 * Converts the StatStorage cursor's current entry into a tuple, using the
 * attinmeta prepared by ts_setup_firstcall, and advances the cursor.  Once
 * every entry has been returned, the private tsstat copy and the
 * StatStorage are freed.  (Datum) 0 is then returned as the end marker.
 */
static Datum
ts_process_call(FuncCallContext *funcctx)
{
	StatStorage *st = (StatStorage *) funcctx->user_fctx;
	StatEntry  *entry;
	HeapTuple	tuple;
	Datum		row;
	char	   *values[3];
	char		ndoc[16];
	char		nentry[16];
	char	   *word;

	if (st->cur >= st->stat->size)
	{
		/* exhausted: release the private copy and signal completion */
		pfree(st->stat);
		pfree(st);
		return (Datum) 0;
	}

	entry = STATPTR(st->stat) + st->cur;

	/* the stored word is not NUL-terminated; make a C-string copy */
	word = palloc(entry->len + 1);
	memcpy(word, STATSTRPTR(st->stat) + entry->pos, entry->len);
	word[entry->len] = '\0';
	sprintf(ndoc, "%d", entry->ndoc);
	sprintf(nentry, "%d", entry->nentry);

	values[0] = word;
	values[1] = ndoc;
	values[2] = nentry;

	tuple = BuildTupleFromCStrings(funcctx->attinmeta, values);
	row = HeapTupleGetDatum(tuple);

	pfree(word);
	st->cur++;
	return row;
}

PG_FUNCTION_INFO_V1(ts_accum_finish);
Datum		ts_accum_finish(PG_FUNCTION_ARGS);
Datum
ts_accum_finish(PG_FUNCTION_ARGS)
{
	FuncCallContext *funcctx;
	Datum		result;

	if (SRF_IS_FIRSTCALL())
	{
		funcctx = SRF_FIRSTCALL_INIT();
		ts_setup_firstcall(fcinfo, funcctx, (tsstat *) PG_GETARG_POINTER(0));
	}

	funcctx = SRF_PERCALL_SETUP();
	if ((result = ts_process_call(funcctx)) != (Datum) 0)
		SRF_RETURN_NEXT(funcctx, result);
	SRF_RETURN_DONE(funcctx);
}

static Oid	tiOid = InvalidOid;

static void
get_ti_Oid(void)
{
	int			ret;
	bool		isnull;

	if ((ret = SPI_exec("select oid from pg_type where typname='tsvector'", 1)) < 0)
		/* internal error */
		elog(ERROR, "SPI_exec to get tsvector oid returns %d", ret);

	if (SPI_processed < 1)
		/* internal error */
		elog(ERROR, "there is no tsvector type");
	tiOid = DatumGetObjectId(SPI_getbinval(SPI_tuptable->vals[0], SPI_tuptable->tupdesc, 1, &isnull));
	if (tiOid == InvalidOid)
		/* internal error */
		elog(ERROR, "tsvector type has InvalidOid");
}

/*
 * ts_stat_weight_mask - convert a weight-letter string into a tsstat mask.
 *
 * Each single-byte letter A/B/C/D (either case) sets bit 3/2/1/0.  Any other
 * character, including multibyte ones, is ignored.  An empty string yields 0,
 * which ts_accum treats as "count every weight".
 */
static int
ts_stat_weight_mask(text *ws)
{
	char	   *buf = VARDATA(ws);
	char	   *end = VARDATA(ws) + (VARSIZE(ws) - VARHDRSZ);
	int			mask = 0;

	while (buf < end)
	{
		int			clen = pg_mblen(buf);

		if (clen == 1)
		{
			switch (*buf)
			{
				case 'A':
				case 'a':
					mask |= 1 << 3;
					break;
				case 'B':
				case 'b':
					mask |= 1 << 2;
					break;
				case 'C':
				case 'c':
					mask |= 1 << 1;
					break;
				case 'D':
				case 'd':
					mask |= 1;
					break;
				default:
					/* unrecognised characters are silently ignored */
					break;
			}
		}
		buf += clen;
	}

	return mask;
}

/*
 * ts_stat_sql - run a query yielding one tsvector column and accumulate
 * lexeme statistics over all of its non-NULL values.
 *
 * txt: the query text
 * ws:  optional weight letters restricting which positions are counted
 *      (NULL counts all weights)
 *
 * Requires an open SPI connection; the result is palloc'd in the current
 * memory context.  Rows are pulled through a cursor in batches of 100.
 * Raises an error in these cases:
 *   - the query cannot be prepared or opened;
 *   - it does not return exactly one column;
 *   - that column is not tsvector.
 */
static tsstat *
ts_stat_sql(text *txt, text *ws)
{
	char	   *query = text2char(txt);
	int			i;
	tsstat	   *newstat,
			   *stat;
	bool		isnull;
	Portal		portal;
	void	   *plan;

	if (tiOid == InvalidOid)
		get_ti_Oid();

	if ((plan = SPI_prepare(query, 0, NULL)) == NULL)
		/* internal error */
		elog(ERROR, "SPI_prepare('%s') returns NULL", query);

	if ((portal = SPI_cursor_open(NULL, plan, NULL, NULL, false)) == NULL)
		/* internal error */
		elog(ERROR, "SPI_cursor_open('%s') returns NULL", query);

	SPI_cursor_fetch(portal, true, 100);

	/* validate the result shape using the first batch's descriptor */
	if (SPI_tuptable->tupdesc->natts != 1)
		/* internal error */
		elog(ERROR, "number of fields doesn't equal to 1");

	if (SPI_gettypeid(SPI_tuptable->tupdesc, 1) != tiOid)
		/* internal error */
		elog(ERROR, "column isn't of tsvector type");

	stat = palloc(STATHDRSIZE);
	stat->len = STATHDRSIZE;
	stat->size = 0;
	stat->weight = (ws) ? ts_stat_weight_mask(ws) : 0;

	while (SPI_processed > 0)
	{
		for (i = 0; i < SPI_processed; i++)
		{
			Datum		data = SPI_getbinval(SPI_tuptable->vals[i], SPI_tuptable->tupdesc, 1, &isnull);

			if (!isnull)
			{
				newstat = (tsstat *) DatumGetPointer(DirectFunctionCall2(ts_accum,
													   PointerGetDatum(stat),
																		 data));
				/* ts_accum may return a new, larger state; drop the old one */
				if (stat != newstat && stat)
					pfree(stat);
				stat = newstat;
			}
		}

		SPI_freetuptable(SPI_tuptable);
		SPI_cursor_fetch(portal, true, 100);
	}

	SPI_freetuptable(SPI_tuptable);
	SPI_cursor_close(portal);
	SPI_freeplan(plan);
	pfree(query);

	return stat;
}

PG_FUNCTION_INFO_V1(ts_stat);
Datum		ts_stat(PG_FUNCTION_ARGS);
Datum
ts_stat(PG_FUNCTION_ARGS)
{
	FuncCallContext *funcctx;
	Datum		result;

	if (SRF_IS_FIRSTCALL())
	{
		tsstat	   *stat;
		text	   *txt = PG_GETARG_TEXT_P(0);
		text	   *ws = (PG_NARGS() > 1) ? PG_GETARG_TEXT_P(1) : NULL;

		funcctx = SRF_FIRSTCALL_INIT();
		SPI_connect();
		stat = ts_stat_sql(txt, ws);
		PG_FREE_IF_COPY(txt, 0);
		if (PG_NARGS() > 1)
			PG_FREE_IF_COPY(ws, 1);
		ts_setup_firstcall(fcinfo, funcctx, stat);
		SPI_finish();
	}

	funcctx = SRF_PERCALL_SETUP();
	if ((result = ts_process_call(funcctx)) != (Datum) 0)
		SRF_RETURN_NEXT(funcctx, result);
	SRF_RETURN_DONE(funcctx);
}