From 029c5ac03db72f1898ee17e417650a2e0764b239 Mon Sep 17 00:00:00 2001
From: Peter Eisentraut <peter@eisentraut.org>
Date: Fri, 16 Apr 2021 11:46:01 +0200
Subject: [PATCH] psql: Refine lexing of BEGIN...END blocks in CREATE FUNCTION
 statements

Only track BEGIN...END blocks if they are in a CREATE [OR REPLACE]
{FUNCTION|PROCEDURE} statement.  Ignore if in parentheses.

Reviewed-by: Laurenz Albe <laurenz.albe@cybertec.at>
Discussion: https://www.postgresql.org/message-id/cee01d26fe55bc086b3bcf10bfe4e8d450e2f608.camel@cybertec.at
---
 src/fe_utils/psqlscan.l             | 53 ++++++++++++++++++++++++-----
 src/include/fe_utils/psqlscan_int.h |  8 ++++-
 2 files changed, 52 insertions(+), 9 deletions(-)

diff --git a/src/fe_utils/psqlscan.l b/src/fe_utils/psqlscan.l
index 4ec57e96a9d..991b7de0b55 100644
--- a/src/fe_utils/psqlscan.l
+++ b/src/fe_utils/psqlscan.l
@@ -870,18 +870,55 @@ other			.
 
 
 {identifier}	{
+					/*
+					 * We need to track if we are inside a BEGIN .. END block
+					 * in a function definition, so that semicolons contained
+					 * therein don't terminate the whole statement.  Short of
+					 * writing a full parser here, the following heuristic
+					 * should work.  First, we track whether the beginning of
+					 * the statement matches CREATE [OR REPLACE]
+					 * {FUNCTION|PROCEDURE}
+					 */
+
+					if (cur_state->identifier_count == 0)
+						memset(cur_state->identifiers, 0, sizeof(cur_state->identifiers));
+
+					if (pg_strcasecmp(yytext, "create") == 0 ||
+						pg_strcasecmp(yytext, "function") == 0 ||
+						pg_strcasecmp(yytext, "procedure") == 0 ||
+						pg_strcasecmp(yytext, "or") == 0 ||
+						pg_strcasecmp(yytext, "replace") == 0)
+					{
+						if (cur_state->identifier_count < sizeof(cur_state->identifiers))
+							cur_state->identifiers[cur_state->identifier_count] = pg_tolower((unsigned char) yytext[0]);
+					}
+
 					cur_state->identifier_count++;
-					if (pg_strcasecmp(yytext, "begin") == 0
-						|| pg_strcasecmp(yytext, "case") == 0)
+
+					if (cur_state->identifiers[0] == 'c' &&
+						(cur_state->identifiers[1] == 'f' || cur_state->identifiers[1] == 'p' ||
+						 (cur_state->identifiers[1] == 'o' && cur_state->identifiers[2] == 'r' &&
+						  (cur_state->identifiers[3] == 'f' || cur_state->identifiers[3] == 'p'))) &&
+						cur_state->paren_depth == 0)
 					{
-						if (cur_state->identifier_count > 1)
+						if (pg_strcasecmp(yytext, "begin") == 0)
 							cur_state->begin_depth++;
+						else if (pg_strcasecmp(yytext, "case") == 0)
+						{
+							/*
+							 * CASE also ends with END.  We only need to track
+							 * this if we are already inside a BEGIN.
+							 */
+							if (cur_state->begin_depth >= 1)
+								cur_state->begin_depth++;
+						}
+						else if (pg_strcasecmp(yytext, "end") == 0)
+						{
+							if (cur_state->begin_depth > 0)
+								cur_state->begin_depth--;
+						}
 					}
-					else if (pg_strcasecmp(yytext, "end") == 0)
-					{
-						if (cur_state->begin_depth > 0)
-							cur_state->begin_depth--;
-					}
+
 					ECHO;
 				}
 
diff --git a/src/include/fe_utils/psqlscan_int.h b/src/include/fe_utils/psqlscan_int.h
index 91d7d4d5c6c..8ada9770927 100644
--- a/src/include/fe_utils/psqlscan_int.h
+++ b/src/include/fe_utils/psqlscan_int.h
@@ -114,8 +114,14 @@ typedef struct PsqlScanStateData
 	int			paren_depth;	/* depth of nesting in parentheses */
 	int			xcdepth;		/* depth of nesting in slash-star comments */
 	char	   *dolqstart;		/* current $foo$ quote start string */
+
+	/*
+	 * State to track boundaries of BEGIN ... END blocks in function
+	 * definitions, so that semicolons do not send query too early.
+	 */
 	int			identifier_count;	/* identifiers since start of statement */
-	int			begin_depth;	/* depth of begin/end routine body blocks */
+	char		identifiers[4]; /* records the first few identifiers */
+	int			begin_depth;	/* depth of begin/end pairs */
 
 	/*
 	 * Callback functions provided by the program making use of the lexer,