Convert contrib/ltree's input functions to report errors softly
authorAndrew Dunstan <andrew@dunslane.net>
Wed, 28 Dec 2022 14:58:04 +0000 (09:58 -0500)
committerAndrew Dunstan <andrew@dunslane.net>
Wed, 28 Dec 2022 15:00:12 +0000 (10:00 -0500)
Reviewed by Tom Lane and Amul Sul

Discussion: https://postgr.es/m/49e598c2-cfe8-0928-b6fb-d0cc51aab626@dunslane.net

contrib/ltree/expected/ltree.out
contrib/ltree/ltree_io.c
contrib/ltree/ltxtquery_io.c
contrib/ltree/sql/ltree.sql

index c6d8f3ef75e6f44d3f4ae7b1128bca57a7edcb66..b95be71c781b24d7f1a7641a394df58b4609282d 100644 (file)
@@ -8084,3 +8084,28 @@ SELECT count(*) FROM _ltreetest WHERE t ? '{23.*.1,23.*.2}' ;
     15
 (1 row)
 
+-- test non-error-throwing input
+SELECT str as "value", typ as "type",
+       pg_input_is_valid(str,typ) as ok,
+       pg_input_error_message(str,typ) as errmsg
+FROM (VALUES ('.2.3', 'ltree'),
+             ('1.2.', 'ltree'),
+             ('1.2.3','ltree'),
+             ('@.2.3','lquery'),
+             (' 2.3', 'lquery'),
+             ('1.2.3','lquery'),
+             ('$tree & aWdf@*','ltxtquery'),
+             ('!tree & aWdf@*','ltxtquery'))
+      AS a(str,typ);
+     value      |   type    | ok |               errmsg               
+----------------+-----------+----+------------------------------------
+ .2.3           | ltree     | f  | ltree syntax error at character 1
+ 1.2.           | ltree     | f  | ltree syntax error
+ 1.2.3          | ltree     | t  | 
+ @.2.3          | lquery    | f  | lquery syntax error at character 1
+  2.3           | lquery    | f  | lquery syntax error at character 1
+ 1.2.3          | lquery    | t  | 
+ $tree & aWdf@* | ltxtquery | f  | operand syntax error
+ !tree & aWdf@* | ltxtquery | t  | 
+(8 rows)
+
index 15115cb29f3c3b5ab5f4f0a0416db66ce5863a49..cb9feebe78c3cbf45fb9db091a075b42f987a0dd 100644 (file)
@@ -24,8 +24,8 @@ typedef struct
 #define LTPRS_WAITNAME 0
 #define LTPRS_WAITDELIM 1
 
-static void finish_nodeitem(nodeitem *lptr, const char *ptr,
-                           bool is_lquery, int pos);
+static bool finish_nodeitem(nodeitem *lptr, const char *ptr,
+                           bool is_lquery, int pos, struct Node *escontext);
 
 
 /*
@@ -33,7 +33,7 @@ static void finish_nodeitem(nodeitem *lptr, const char *ptr,
  * returns an ltree
  */
 static ltree *
-parse_ltree(const char *buf)
+parse_ltree(const char *buf, struct Node *escontext)
 {
    const char *ptr;
    nodeitem   *list,
@@ -46,7 +46,7 @@ parse_ltree(const char *buf)
    int         charlen;
    int         pos = 1;        /* character position for error messages */
 
-#define UNCHAR ereport(ERROR, \
+#define UNCHAR ereturn(escontext, NULL,\
                       errcode(ERRCODE_SYNTAX_ERROR), \
                       errmsg("ltree syntax error at character %d", \
                              pos))
@@ -61,7 +61,7 @@ parse_ltree(const char *buf)
    }
 
    if (num + 1 > LTREE_MAX_LEVELS)
-       ereport(ERROR,
+       ereturn(escontext, NULL,
                (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
                 errmsg("number of ltree labels (%d) exceeds the maximum allowed (%d)",
                        num + 1, LTREE_MAX_LEVELS)));
@@ -86,7 +86,8 @@ parse_ltree(const char *buf)
            case LTPRS_WAITDELIM:
                if (t_iseq(ptr, '.'))
                {
-                   finish_nodeitem(lptr, ptr, false, pos);
+                   if (!finish_nodeitem(lptr, ptr, false, pos, escontext))
+                       return NULL;
                    totallen += MAXALIGN(lptr->len + LEVEL_HDRSIZE);
                    lptr++;
                    state = LTPRS_WAITNAME;
@@ -105,12 +106,13 @@ parse_ltree(const char *buf)
 
    if (state == LTPRS_WAITDELIM)
    {
-       finish_nodeitem(lptr, ptr, false, pos);
+       if (!finish_nodeitem(lptr, ptr, false, pos, escontext))
+           return NULL;
        totallen += MAXALIGN(lptr->len + LEVEL_HDRSIZE);
        lptr++;
    }
    else if (!(state == LTPRS_WAITNAME && lptr == list))
-       ereport(ERROR,
+       ereturn(escontext, NULL,
                (errcode(ERRCODE_SYNTAX_ERROR),
                 errmsg("ltree syntax error"),
                 errdetail("Unexpected end of input.")));
@@ -172,8 +174,12 @@ Datum
 ltree_in(PG_FUNCTION_ARGS)
 {
    char       *buf = (char *) PG_GETARG_POINTER(0);
+   ltree      *res;
 
-   PG_RETURN_POINTER(parse_ltree(buf));
+   if ((res = parse_ltree(buf, fcinfo->context)) == NULL)
+       PG_RETURN_NULL();
+
+   PG_RETURN_POINTER(res);
 }
 
 PG_FUNCTION_INFO_V1(ltree_out);
@@ -232,7 +238,7 @@ ltree_recv(PG_FUNCTION_ARGS)
        elog(ERROR, "unsupported ltree version number %d", version);
 
    str = pq_getmsgtext(buf, buf->len - buf->cursor, &nbytes);
-   res = parse_ltree(str);
+   res = parse_ltree(str, NULL);
    pfree(str);
 
    PG_RETURN_POINTER(res);
@@ -259,7 +265,7 @@ ltree_recv(PG_FUNCTION_ARGS)
  * returns an lquery
  */
 static lquery *
-parse_lquery(const char *buf)
+parse_lquery(const char *buf, struct Node *escontext)
 {
    const char *ptr;
    int         num = 0,
@@ -277,7 +283,7 @@ parse_lquery(const char *buf)
    int         charlen;
    int         pos = 1;        /* character position for error messages */
 
-#define UNCHAR ereport(ERROR, \
+#define UNCHAR ereturn(escontext, NULL,\
                       errcode(ERRCODE_SYNTAX_ERROR), \
                       errmsg("lquery syntax error at character %d", \
                              pos))
@@ -297,7 +303,7 @@ parse_lquery(const char *buf)
 
    num++;
    if (num > LQUERY_MAX_LEVELS)
-       ereport(ERROR,
+       ereturn(escontext, NULL,
                (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
                 errmsg("number of lquery items (%d) exceeds the maximum allowed (%d)",
                        num, LQUERY_MAX_LEVELS)));
@@ -361,18 +367,21 @@ parse_lquery(const char *buf)
                }
                else if (t_iseq(ptr, '|'))
                {
-                   finish_nodeitem(lptr, ptr, true, pos);
+                   if (!finish_nodeitem(lptr, ptr, true, pos, escontext))
+                       return NULL;
                    state = LQPRS_WAITVAR;
                }
                else if (t_iseq(ptr, '{'))
                {
-                   finish_nodeitem(lptr, ptr, true, pos);
+                   if (!finish_nodeitem(lptr, ptr, true, pos, escontext))
+                       return NULL;
                    curqlevel->flag |= LQL_COUNT;
                    state = LQPRS_WAITFNUM;
                }
                else if (t_iseq(ptr, '.'))
                {
-                   finish_nodeitem(lptr, ptr, true, pos);
+                   if (!finish_nodeitem(lptr, ptr, true, pos, escontext))
+                       return NULL;
                    state = LQPRS_WAITLEVEL;
                    curqlevel = NEXTLEV(curqlevel);
                }
@@ -407,7 +416,7 @@ parse_lquery(const char *buf)
                    int         low = atoi(ptr);
 
                    if (low < 0 || low > LTREE_MAX_LEVELS)
-                       ereport(ERROR,
+                       ereturn(escontext, NULL,
                                (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
                                 errmsg("lquery syntax error"),
                                 errdetail("Low limit (%d) exceeds the maximum allowed (%d), at character %d.",
@@ -425,13 +434,13 @@ parse_lquery(const char *buf)
                    int         high = atoi(ptr);
 
                    if (high < 0 || high > LTREE_MAX_LEVELS)
-                       ereport(ERROR,
+                       ereturn(escontext, NULL,
                                (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
                                 errmsg("lquery syntax error"),
                                 errdetail("High limit (%d) exceeds the maximum allowed (%d), at character %d.",
                                           high, LTREE_MAX_LEVELS, pos)));
                    else if (curqlevel->low > high)
-                       ereport(ERROR,
+                       ereturn(escontext, NULL,
                                (errcode(ERRCODE_SYNTAX_ERROR),
                                 errmsg("lquery syntax error"),
                                 errdetail("Low limit (%d) is greater than high limit (%d), at character %d.",
@@ -485,11 +494,14 @@ parse_lquery(const char *buf)
    }
 
    if (state == LQPRS_WAITDELIM)
-       finish_nodeitem(lptr, ptr, true, pos);
+   {
+       if (!finish_nodeitem(lptr, ptr, true, pos, escontext))
+           return false;
+   }
    else if (state == LQPRS_WAITOPEN)
        curqlevel->high = LTREE_MAX_LEVELS;
    else if (state != LQPRS_WAITEND)
-       ereport(ERROR,
+       ereturn(escontext, NULL,
                (errcode(ERRCODE_SYNTAX_ERROR),
                 errmsg("lquery syntax error"),
                 errdetail("Unexpected end of input.")));
@@ -569,8 +581,9 @@ parse_lquery(const char *buf)
  * Close out parsing an ltree or lquery nodeitem:
  * compute the correct length, and complain if it's not OK
  */
-static void
-finish_nodeitem(nodeitem *lptr, const char *ptr, bool is_lquery, int pos)
+static bool
+finish_nodeitem(nodeitem *lptr, const char *ptr, bool is_lquery, int pos,
+   struct Node *escontext)
 {
    if (is_lquery)
    {
@@ -591,18 +604,19 @@ finish_nodeitem(nodeitem *lptr, const char *ptr, bool is_lquery, int pos)
 
    /* Complain if it's empty or too long */
    if (lptr->len == 0)
-       ereport(ERROR,
+       ereturn(escontext, false,
                (errcode(ERRCODE_SYNTAX_ERROR),
                 is_lquery ?
                 errmsg("lquery syntax error at character %d", pos) :
                 errmsg("ltree syntax error at character %d", pos),
                 errdetail("Empty labels are not allowed.")));
    if (lptr->wlen > LTREE_LABEL_MAX_CHARS)
-       ereport(ERROR,
+       ereturn(escontext, false,
                (errcode(ERRCODE_NAME_TOO_LONG),
                 errmsg("label string is too long"),
                 errdetail("Label length is %d, must be at most %d, at character %d.",
                           lptr->wlen, LTREE_LABEL_MAX_CHARS, pos)));
+   return true;
 }
 
 /*
@@ -730,8 +744,12 @@ Datum
 lquery_in(PG_FUNCTION_ARGS)
 {
    char       *buf = (char *) PG_GETARG_POINTER(0);
+   lquery     *res;
+
+   if ((res = parse_lquery(buf, fcinfo->context)) == NULL)
+       PG_RETURN_NULL();
 
-   PG_RETURN_POINTER(parse_lquery(buf));
+   PG_RETURN_POINTER(res);
 }
 
 PG_FUNCTION_INFO_V1(lquery_out);
@@ -790,7 +808,7 @@ lquery_recv(PG_FUNCTION_ARGS)
        elog(ERROR, "unsupported lquery version number %d", version);
 
    str = pq_getmsgtext(buf, buf->len - buf->cursor, &nbytes);
-   res = parse_lquery(str);
+   res = parse_lquery(str, NULL);
    pfree(str);
 
    PG_RETURN_POINTER(res);
index 8ab0ce8e52b338386bc5e07d387d5eda5c66e867..a16e577303a54e7b61761d096895d5476d774cc5 100644 (file)
@@ -11,6 +11,7 @@
 #include "libpq/pqformat.h"
 #include "ltree.h"
 #include "miscadmin.h"
+#include "nodes/miscnodes.h"
 
 
 /* parser's states */
@@ -37,6 +38,7 @@ typedef struct
    char       *buf;
    int32       state;
    int32       count;
+   struct Node *escontext;
    /* reverse polish notation in list (for temporary usage) */
    NODE       *str;
    /* number in str */
@@ -51,6 +53,8 @@ typedef struct
 
 /*
  * get token from query string
+ *
+ * caller needs to check if a soft-error was set if the result is ERR.
  */
 static int32
 gettoken_query(QPRS_STATE *state, int32 *val, int32 *lenval, char **strval, uint16 *flag)
@@ -84,7 +88,7 @@ gettoken_query(QPRS_STATE *state, int32 *val, int32 *lenval, char **strval, uint
                    *flag = 0;
                }
                else if (!t_isspace(state->buf))
-                   ereport(ERROR,
+                   ereturn(state->escontext, ERR,
                            (errcode(ERRCODE_SYNTAX_ERROR),
                             errmsg("operand syntax error")));
                break;
@@ -92,7 +96,7 @@ gettoken_query(QPRS_STATE *state, int32 *val, int32 *lenval, char **strval, uint
                if (ISALNUM(state->buf))
                {
                    if (*flag)
-                       ereport(ERROR,
+                       ereturn(state->escontext, ERR,
                                (errcode(ERRCODE_SYNTAX_ERROR),
                                 errmsg("modifiers syntax error")));
                    *lenval += charlen;
@@ -124,9 +128,13 @@ gettoken_query(QPRS_STATE *state, int32 *val, int32 *lenval, char **strval, uint
                    return (state->count < 0) ? ERR : CLOSE;
                }
                else if (*(state->buf) == '\0')
+               {
                    return (state->count) ? ERR : END;
+               }
                else if (!t_iseq(state->buf, ' '))
+               {
                    return ERR;
+               }
                break;
            default:
                return ERR;
@@ -135,12 +143,14 @@ gettoken_query(QPRS_STATE *state, int32 *val, int32 *lenval, char **strval, uint
 
        state->buf += charlen;
    }
+
+   /* should not get here */
 }
 
 /*
  * push new one in polish notation reverse view
  */
-static void
+static bool
 pushquery(QPRS_STATE *state, int32 type, int32 val, int32 distance, int32 lenval, uint16 flag)
 {
    NODE       *tmp = (NODE *) palloc(sizeof(NODE));
@@ -149,11 +159,11 @@ pushquery(QPRS_STATE *state, int32 type, int32 val, int32 distance, int32 lenval
    tmp->val = val;
    tmp->flag = flag;
    if (distance > 0xffff)
-       ereport(ERROR,
+       ereturn(state->escontext, false,
                (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                 errmsg("value is too big")));
    if (lenval > 0xff)
-       ereport(ERROR,
+       ereturn(state->escontext, false,
                (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                 errmsg("operand is too long")));
    tmp->distance = distance;
@@ -161,21 +171,23 @@ pushquery(QPRS_STATE *state, int32 type, int32 val, int32 distance, int32 lenval
    tmp->next = state->str;
    state->str = tmp;
    state->num++;
+   return true;
 }
 
 /*
  * This function is used for query text parsing
  */
-static void
+static bool
 pushval_asis(QPRS_STATE *state, int type, char *strval, int lenval, uint16 flag)
 {
    if (lenval > 0xffff)
-       ereport(ERROR,
+       ereturn(state->escontext, false,
                (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                 errmsg("word is too long")));
 
-   pushquery(state, type, ltree_crc32_sz(strval, lenval),
-             state->curop - state->op, lenval, flag);
+   if (! pushquery(state, type, ltree_crc32_sz(strval, lenval),
+                   state->curop - state->op, lenval, flag))
+       return false;
 
    while (state->curop - state->op + lenval + 1 >= state->lenop)
    {
@@ -190,6 +202,7 @@ pushval_asis(QPRS_STATE *state, int type, char *strval, int lenval, uint16 flag)
    *(state->curop) = '\0';
    state->curop++;
    state->sumlen += lenval + 1;
+   return true;
 }
 
 #define STACKDEPTH     32
@@ -215,17 +228,22 @@ makepol(QPRS_STATE *state)
        switch (type)
        {
            case VAL:
-               pushval_asis(state, VAL, strval, lenval, flag);
+               if (!pushval_asis(state, VAL, strval, lenval, flag))
+                   return ERR;
                while (lenstack && (stack[lenstack - 1] == (int32) '&' ||
                                    stack[lenstack - 1] == (int32) '!'))
                {
                    lenstack--;
-                   pushquery(state, OPR, stack[lenstack], 0, 0, 0);
+                   if (!pushquery(state, OPR, stack[lenstack], 0, 0, 0))
+                       return ERR;
                }
                break;
            case OPR:
                if (lenstack && val == (int32) '|')
-                   pushquery(state, OPR, val, 0, 0, 0);
+               {
+                   if (!pushquery(state, OPR, val, 0, 0, 0))
+                       return ERR;
+               }
                else
                {
                    if (lenstack == STACKDEPTH)
@@ -242,30 +260,35 @@ makepol(QPRS_STATE *state)
                                    stack[lenstack - 1] == (int32) '!'))
                {
                    lenstack--;
-                   pushquery(state, OPR, stack[lenstack], 0, 0, 0);
+                   if (!pushquery(state, OPR, stack[lenstack], 0, 0, 0))
+                       return ERR;
                }
                break;
            case CLOSE:
                while (lenstack)
                {
                    lenstack--;
-                   pushquery(state, OPR, stack[lenstack], 0, 0, 0);
+                   if (!pushquery(state, OPR, stack[lenstack], 0, 0, 0))
+                       return ERR;
                };
                return END;
                break;
            case ERR:
+               if (SOFT_ERROR_OCCURRED(state->escontext))
+                   return ERR;
+               /* fall through */
            default:
-               ereport(ERROR,
+               ereturn(state->escontext, ERR,
                        (errcode(ERRCODE_SYNTAX_ERROR),
                         errmsg("syntax error")));
 
-               return ERR;
        }
    }
    while (lenstack)
    {
        lenstack--;
-       pushquery(state, OPR, stack[lenstack], 0, 0, 0);
+       if (!pushquery(state, OPR, stack[lenstack], 0, 0, 0))
+           return ERR;
    };
    return END;
 }
@@ -304,7 +327,7 @@ findoprnd(ITEM *ptr, int32 *pos)
  * input
  */
 static ltxtquery *
-queryin(char *buf)
+queryin(char *buf, struct Node *escontext)
 {
    QPRS_STATE  state;
    int32       i;
@@ -325,6 +348,7 @@ queryin(char *buf)
    state.count = 0;
    state.num = 0;
    state.str = NULL;
+   state.escontext = escontext;
 
    /* init list of operand */
    state.sumlen = 0;
@@ -333,15 +357,16 @@ queryin(char *buf)
    *(state.curop) = '\0';
 
    /* parse query & make polish notation (postfix, but in reverse order) */
-   makepol(&state);
+   if (makepol(&state) == ERR)
+       return NULL;
    if (!state.num)
-       ereport(ERROR,
+       ereturn(escontext, NULL,
                (errcode(ERRCODE_SYNTAX_ERROR),
                 errmsg("syntax error"),
                 errdetail("Empty query.")));
 
    if (LTXTQUERY_TOO_BIG(state.num, state.sumlen))
-       ereport(ERROR,
+       ereturn(escontext, NULL,
                (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
                 errmsg("ltxtquery is too large")));
    commonlen = COMPUTESIZE(state.num, state.sumlen);
@@ -382,7 +407,11 @@ PG_FUNCTION_INFO_V1(ltxtq_in);
 Datum
 ltxtq_in(PG_FUNCTION_ARGS)
 {
-   PG_RETURN_POINTER(queryin((char *) PG_GETARG_POINTER(0)));
+   ltxtquery *res;
+
+   if ((res = queryin((char *) PG_GETARG_POINTER(0), fcinfo->context)) == NULL)
+       PG_RETURN_NULL();
+   PG_RETURN_POINTER(res);
 }
 
 /*
@@ -407,7 +436,7 @@ ltxtq_recv(PG_FUNCTION_ARGS)
        elog(ERROR, "unsupported ltxtquery version number %d", version);
 
    str = pq_getmsgtext(buf, buf->len - buf->cursor, &nbytes);
-   res = queryin(str);
+   res = queryin(str, NULL);
    pfree(str);
 
    PG_RETURN_POINTER(res);
index bf733ed17b9332d350d218b157ae86e55e2b522e..eabef4f851ce8c485ff203e11bdab6d1cdf6348e 100644 (file)
@@ -382,3 +382,18 @@ SELECT count(*) FROM _ltreetest WHERE t ~ '23.*{1}.1' ;
 SELECT count(*) FROM _ltreetest WHERE t ~ '23.*.1' ;
 SELECT count(*) FROM _ltreetest WHERE t ~ '23.*.2' ;
 SELECT count(*) FROM _ltreetest WHERE t ? '{23.*.1,23.*.2}' ;
+
+-- test non-error-throwing input
+
+SELECT str as "value", typ as "type",
+       pg_input_is_valid(str,typ) as ok,
+       pg_input_error_message(str,typ) as errmsg
+FROM (VALUES ('.2.3', 'ltree'),
+             ('1.2.', 'ltree'),
+             ('1.2.3','ltree'),
+             ('@.2.3','lquery'),
+             (' 2.3', 'lquery'),
+             ('1.2.3','lquery'),
+             ('$tree & aWdf@*','ltxtquery'),
+             ('!tree & aWdf@*','ltxtquery'))
+      AS a(str,typ);