Refactor pl_funcs.c to provide a usage-independent tree walker.
authorTom Lane <tgl@sss.pgh.pa.us>
Tue, 11 Feb 2025 17:14:12 +0000 (12:14 -0500)
committerTom Lane <tgl@sss.pgh.pa.us>
Tue, 11 Feb 2025 17:14:12 +0000 (12:14 -0500)
We haven't done this up to now because there was only one use-case,
namely plpgsql_free_function_memory's search for expressions to clean
up.  However an upcoming patch has another need for walking plpgsql
functions' statement trees, so let's create sharable tree-walker
infrastructure in the same style as expression_tree_walker().

This patch actually makes the code shorter, although that's
mainly down to having used a more compact coding style.  (I didn't
write a separate subroutine for each statement type, and I made
use of some newer notations like foreach_ptr.)

Author: Tom Lane <tgl@sss.pgh.pa.us>
Reviewed-by: Andrey Borodin <x4mmm@yandex-team.ru>
Reviewed-by: Pavel Borisov <pashkin.elfe@gmail.com>
Discussion: https://postgr.es/m/CACxu=vJaKFNsYxooSnW1wEgsAO5u_v1XYBacfVJ14wgJV_PYeg@mail.gmail.com

src/pl/plpgsql/src/pl_funcs.c

index 8c827fe5cc595becc2bde33608ad14cf857bba37..88e25b54bcd0a11cc030309726fe86434146da43 100644 (file)
@@ -334,387 +334,291 @@ plpgsql_getdiag_kindname(PLpgSQL_getdiag_kind kind)
 
 
 /**********************************************************************
- * Release memory when a PL/pgSQL function is no longer needed
+ * Support for recursing through a PL/pgSQL statement tree
  *
- * The code for recursing through the function tree is really only
- * needed to locate PLpgSQL_expr nodes, which may contain references
- * to saved SPI Plans that must be freed.  The function tree itself,
- * along with subsidiary data, is freed in one swoop by freeing the
- * function's permanent memory context.
+ * The point of this code is to encapsulate knowledge of where the
+ * sub-statements and expressions are in a statement tree, avoiding
+ * duplication of code.  The caller supplies two callbacks, one to
+ * be invoked on statements and one to be invoked on expressions.
+ * (The recursion should be started by invoking the statement callback
+ * on function->action.)  The statement callback should do any
+ * statement-type-specific action it needs, then recurse by calling
+ * plpgsql_statement_tree_walker().  The expression callback can be a
+ * no-op if no per-expression behavior is needed.
  **********************************************************************/
-static void free_stmt(PLpgSQL_stmt *stmt);
-static void free_block(PLpgSQL_stmt_block *block);
-static void free_assign(PLpgSQL_stmt_assign *stmt);
-static void free_if(PLpgSQL_stmt_if *stmt);
-static void free_case(PLpgSQL_stmt_case *stmt);
-static void free_loop(PLpgSQL_stmt_loop *stmt);
-static void free_while(PLpgSQL_stmt_while *stmt);
-static void free_fori(PLpgSQL_stmt_fori *stmt);
-static void free_fors(PLpgSQL_stmt_fors *stmt);
-static void free_forc(PLpgSQL_stmt_forc *stmt);
-static void free_foreach_a(PLpgSQL_stmt_foreach_a *stmt);
-static void free_exit(PLpgSQL_stmt_exit *stmt);
-static void free_return(PLpgSQL_stmt_return *stmt);
-static void free_return_next(PLpgSQL_stmt_return_next *stmt);
-static void free_return_query(PLpgSQL_stmt_return_query *stmt);
-static void free_raise(PLpgSQL_stmt_raise *stmt);
-static void free_assert(PLpgSQL_stmt_assert *stmt);
-static void free_execsql(PLpgSQL_stmt_execsql *stmt);
-static void free_dynexecute(PLpgSQL_stmt_dynexecute *stmt);
-static void free_dynfors(PLpgSQL_stmt_dynfors *stmt);
-static void free_getdiag(PLpgSQL_stmt_getdiag *stmt);
-static void free_open(PLpgSQL_stmt_open *stmt);
-static void free_fetch(PLpgSQL_stmt_fetch *stmt);
-static void free_close(PLpgSQL_stmt_close *stmt);
-static void free_perform(PLpgSQL_stmt_perform *stmt);
-static void free_call(PLpgSQL_stmt_call *stmt);
-static void free_commit(PLpgSQL_stmt_commit *stmt);
-static void free_rollback(PLpgSQL_stmt_rollback *stmt);
-static void free_expr(PLpgSQL_expr *expr);
+typedef void (*plpgsql_stmt_walker_callback) (PLpgSQL_stmt *stmt,
+                                             void *context);
+typedef void (*plpgsql_expr_walker_callback) (PLpgSQL_expr *expr,
+                                             void *context);
 
+/*
+ * As in nodeFuncs.h, we respectfully decline to support the C standard's
+ * position that a pointer to struct is incompatible with "void *".  Instead,
+ * silence related compiler warnings using casts in this macro wrapper.
+ */
+#define plpgsql_statement_tree_walker(s, sw, ew, c) \
+   plpgsql_statement_tree_walker_impl(s, (plpgsql_stmt_walker_callback) (sw), \
+                                      (plpgsql_expr_walker_callback) (ew), c)
 
 static void
-free_stmt(PLpgSQL_stmt *stmt)
+plpgsql_statement_tree_walker_impl(PLpgSQL_stmt *stmt,
+                                  plpgsql_stmt_walker_callback stmt_callback,
+                                  plpgsql_expr_walker_callback expr_callback,
+                                  void *context)
 {
+#define S_WALK(st) stmt_callback(st, context)
+#define E_WALK(ex) expr_callback(ex, context)
+#define S_LIST_WALK(lst) foreach_ptr(PLpgSQL_stmt, st, lst) S_WALK(st)
+#define E_LIST_WALK(lst) foreach_ptr(PLpgSQL_expr, ex, lst) E_WALK(ex)
+
    switch (stmt->cmd_type)
    {
        case PLPGSQL_STMT_BLOCK:
-           free_block((PLpgSQL_stmt_block *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_block *bstmt = (PLpgSQL_stmt_block *) stmt;
+
+               S_LIST_WALK(bstmt->body);
+               if (bstmt->exceptions)
+               {
+                   foreach_ptr(PLpgSQL_exception, exc, bstmt->exceptions->exc_list)
+                   {
+                       /* conditions list has no interesting sub-structure */
+                       S_LIST_WALK(exc->action);
+                   }
+               }
+               break;
+           }
        case PLPGSQL_STMT_ASSIGN:
-           free_assign((PLpgSQL_stmt_assign *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_assign *astmt = (PLpgSQL_stmt_assign *) stmt;
+
+               E_WALK(astmt->expr);
+               break;
+           }
        case PLPGSQL_STMT_IF:
-           free_if((PLpgSQL_stmt_if *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_if *ifstmt = (PLpgSQL_stmt_if *) stmt;
+
+               E_WALK(ifstmt->cond);
+               S_LIST_WALK(ifstmt->then_body);
+               foreach_ptr(PLpgSQL_if_elsif, elif, ifstmt->elsif_list)
+               {
+                   E_WALK(elif->cond);
+                   S_LIST_WALK(elif->stmts);
+               }
+               S_LIST_WALK(ifstmt->else_body);
+               break;
+           }
        case PLPGSQL_STMT_CASE:
-           free_case((PLpgSQL_stmt_case *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_case *cstmt = (PLpgSQL_stmt_case *) stmt;
+
+               E_WALK(cstmt->t_expr);
+               foreach_ptr(PLpgSQL_case_when, cwt, cstmt->case_when_list)
+               {
+                   E_WALK(cwt->expr);
+                   S_LIST_WALK(cwt->stmts);
+               }
+               S_LIST_WALK(cstmt->else_stmts);
+               break;
+           }
        case PLPGSQL_STMT_LOOP:
-           free_loop((PLpgSQL_stmt_loop *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_loop *lstmt = (PLpgSQL_stmt_loop *) stmt;
+
+               S_LIST_WALK(lstmt->body);
+               break;
+           }
        case PLPGSQL_STMT_WHILE:
-           free_while((PLpgSQL_stmt_while *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_while *wstmt = (PLpgSQL_stmt_while *) stmt;
+
+               E_WALK(wstmt->cond);
+               S_LIST_WALK(wstmt->body);
+               break;
+           }
        case PLPGSQL_STMT_FORI:
-           free_fori((PLpgSQL_stmt_fori *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_fori *fori = (PLpgSQL_stmt_fori *) stmt;
+
+               E_WALK(fori->lower);
+               E_WALK(fori->upper);
+               E_WALK(fori->step);
+               S_LIST_WALK(fori->body);
+               break;
+           }
        case PLPGSQL_STMT_FORS:
-           free_fors((PLpgSQL_stmt_fors *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_fors *fors = (PLpgSQL_stmt_fors *) stmt;
+
+               S_LIST_WALK(fors->body);
+               E_WALK(fors->query);
+               break;
+           }
        case PLPGSQL_STMT_FORC:
-           free_forc((PLpgSQL_stmt_forc *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_forc *forc = (PLpgSQL_stmt_forc *) stmt;
+
+               S_LIST_WALK(forc->body);
+               E_WALK(forc->argquery);
+               break;
+           }
        case PLPGSQL_STMT_FOREACH_A:
-           free_foreach_a((PLpgSQL_stmt_foreach_a *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_foreach_a *fstmt = (PLpgSQL_stmt_foreach_a *) stmt;
+
+               E_WALK(fstmt->expr);
+               S_LIST_WALK(fstmt->body);
+               break;
+           }
        case PLPGSQL_STMT_EXIT:
-           free_exit((PLpgSQL_stmt_exit *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_exit *estmt = (PLpgSQL_stmt_exit *) stmt;
+
+               E_WALK(estmt->cond);
+               break;
+           }
        case PLPGSQL_STMT_RETURN:
-           free_return((PLpgSQL_stmt_return *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_return *rstmt = (PLpgSQL_stmt_return *) stmt;
+
+               E_WALK(rstmt->expr);
+               break;
+           }
        case PLPGSQL_STMT_RETURN_NEXT:
-           free_return_next((PLpgSQL_stmt_return_next *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_return_next *rstmt = (PLpgSQL_stmt_return_next *) stmt;
+
+               E_WALK(rstmt->expr);
+               break;
+           }
        case PLPGSQL_STMT_RETURN_QUERY:
-           free_return_query((PLpgSQL_stmt_return_query *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_return_query *rstmt = (PLpgSQL_stmt_return_query *) stmt;
+
+               E_WALK(rstmt->query);
+               E_WALK(rstmt->dynquery);
+               E_LIST_WALK(rstmt->params);
+               break;
+           }
        case PLPGSQL_STMT_RAISE:
-           free_raise((PLpgSQL_stmt_raise *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_raise *rstmt = (PLpgSQL_stmt_raise *) stmt;
+
+               E_LIST_WALK(rstmt->params);
+               foreach_ptr(PLpgSQL_raise_option, opt, rstmt->options)
+               {
+                   E_WALK(opt->expr);
+               }
+               break;
+           }
        case PLPGSQL_STMT_ASSERT:
-           free_assert((PLpgSQL_stmt_assert *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_assert *astmt = (PLpgSQL_stmt_assert *) stmt;
+
+               E_WALK(astmt->cond);
+               E_WALK(astmt->message);
+               break;
+           }
        case PLPGSQL_STMT_EXECSQL:
-           free_execsql((PLpgSQL_stmt_execsql *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_execsql *xstmt = (PLpgSQL_stmt_execsql *) stmt;
+
+               E_WALK(xstmt->sqlstmt);
+               break;
+           }
        case PLPGSQL_STMT_DYNEXECUTE:
-           free_dynexecute((PLpgSQL_stmt_dynexecute *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_dynexecute *dstmt = (PLpgSQL_stmt_dynexecute *) stmt;
+
+               E_WALK(dstmt->query);
+               E_LIST_WALK(dstmt->params);
+               break;
+           }
        case PLPGSQL_STMT_DYNFORS:
-           free_dynfors((PLpgSQL_stmt_dynfors *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_dynfors *dstmt = (PLpgSQL_stmt_dynfors *) stmt;
+
+               S_LIST_WALK(dstmt->body);
+               E_WALK(dstmt->query);
+               E_LIST_WALK(dstmt->params);
+               break;
+           }
        case PLPGSQL_STMT_GETDIAG:
-           free_getdiag((PLpgSQL_stmt_getdiag *) stmt);
-           break;
+           {
+               /* no interesting sub-structure */
+               break;
+           }
        case PLPGSQL_STMT_OPEN:
-           free_open((PLpgSQL_stmt_open *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_open *ostmt = (PLpgSQL_stmt_open *) stmt;
+
+               E_WALK(ostmt->argquery);
+               E_WALK(ostmt->query);
+               E_WALK(ostmt->dynquery);
+               E_LIST_WALK(ostmt->params);
+               break;
+           }
        case PLPGSQL_STMT_FETCH:
-           free_fetch((PLpgSQL_stmt_fetch *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_fetch *fstmt = (PLpgSQL_stmt_fetch *) stmt;
+
+               E_WALK(fstmt->expr);
+               break;
+           }
        case PLPGSQL_STMT_CLOSE:
-           free_close((PLpgSQL_stmt_close *) stmt);
-           break;
+           {
+               /* no interesting sub-structure */
+               break;
+           }
        case PLPGSQL_STMT_PERFORM:
-           free_perform((PLpgSQL_stmt_perform *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_perform *pstmt = (PLpgSQL_stmt_perform *) stmt;
+
+               E_WALK(pstmt->expr);
+               break;
+           }
        case PLPGSQL_STMT_CALL:
-           free_call((PLpgSQL_stmt_call *) stmt);
-           break;
+           {
+               PLpgSQL_stmt_call *cstmt = (PLpgSQL_stmt_call *) stmt;
+
+               E_WALK(cstmt->expr);
+               break;
+           }
        case PLPGSQL_STMT_COMMIT:
-           free_commit((PLpgSQL_stmt_commit *) stmt);
-           break;
        case PLPGSQL_STMT_ROLLBACK:
-           free_rollback((PLpgSQL_stmt_rollback *) stmt);
-           break;
+           {
+               /* no interesting sub-structure */
+               break;
+           }
        default:
            elog(ERROR, "unrecognized cmd_type: %d", stmt->cmd_type);
            break;
    }
 }
 
-static void
-free_stmts(List *stmts)
-{
-   ListCell   *s;
-
-   foreach(s, stmts)
-   {
-       free_stmt((PLpgSQL_stmt *) lfirst(s));
-   }
-}
-
-static void
-free_block(PLpgSQL_stmt_block *block)
-{
-   free_stmts(block->body);
-   if (block->exceptions)
-   {
-       ListCell   *e;
-
-       foreach(e, block->exceptions->exc_list)
-       {
-           PLpgSQL_exception *exc = (PLpgSQL_exception *) lfirst(e);
-
-           free_stmts(exc->action);
-       }
-   }
-}
-
-static void
-free_assign(PLpgSQL_stmt_assign *stmt)
-{
-   free_expr(stmt->expr);
-}
-
-static void
-free_if(PLpgSQL_stmt_if *stmt)
-{
-   ListCell   *l;
-
-   free_expr(stmt->cond);
-   free_stmts(stmt->then_body);
-   foreach(l, stmt->elsif_list)
-   {
-       PLpgSQL_if_elsif *elif = (PLpgSQL_if_elsif *) lfirst(l);
-
-       free_expr(elif->cond);
-       free_stmts(elif->stmts);
-   }
-   free_stmts(stmt->else_body);
-}
-
-static void
-free_case(PLpgSQL_stmt_case *stmt)
-{
-   ListCell   *l;
-
-   free_expr(stmt->t_expr);
-   foreach(l, stmt->case_when_list)
-   {
-       PLpgSQL_case_when *cwt = (PLpgSQL_case_when *) lfirst(l);
-
-       free_expr(cwt->expr);
-       free_stmts(cwt->stmts);
-   }
-   free_stmts(stmt->else_stmts);
-}
-
-static void
-free_loop(PLpgSQL_stmt_loop *stmt)
-{
-   free_stmts(stmt->body);
-}
-
-static void
-free_while(PLpgSQL_stmt_while *stmt)
-{
-   free_expr(stmt->cond);
-   free_stmts(stmt->body);
-}
-
-static void
-free_fori(PLpgSQL_stmt_fori *stmt)
-{
-   free_expr(stmt->lower);
-   free_expr(stmt->upper);
-   free_expr(stmt->step);
-   free_stmts(stmt->body);
-}
-
-static void
-free_fors(PLpgSQL_stmt_fors *stmt)
-{
-   free_stmts(stmt->body);
-   free_expr(stmt->query);
-}
-
-static void
-free_forc(PLpgSQL_stmt_forc *stmt)
-{
-   free_stmts(stmt->body);
-   free_expr(stmt->argquery);
-}
-
-static void
-free_foreach_a(PLpgSQL_stmt_foreach_a *stmt)
-{
-   free_expr(stmt->expr);
-   free_stmts(stmt->body);
-}
-
-static void
-free_open(PLpgSQL_stmt_open *stmt)
-{
-   ListCell   *lc;
-
-   free_expr(stmt->argquery);
-   free_expr(stmt->query);
-   free_expr(stmt->dynquery);
-   foreach(lc, stmt->params)
-   {
-       free_expr((PLpgSQL_expr *) lfirst(lc));
-   }
-}
-
-static void
-free_fetch(PLpgSQL_stmt_fetch *stmt)
-{
-   free_expr(stmt->expr);
-}
 
-static void
-free_close(PLpgSQL_stmt_close *stmt)
-{
-}
-
-static void
-free_perform(PLpgSQL_stmt_perform *stmt)
-{
-   free_expr(stmt->expr);
-}
-
-static void
-free_call(PLpgSQL_stmt_call *stmt)
-{
-   free_expr(stmt->expr);
-}
-
-static void
-free_commit(PLpgSQL_stmt_commit *stmt)
-{
-}
-
-static void
-free_rollback(PLpgSQL_stmt_rollback *stmt)
-{
-}
-
-static void
-free_exit(PLpgSQL_stmt_exit *stmt)
-{
-   free_expr(stmt->cond);
-}
-
-static void
-free_return(PLpgSQL_stmt_return *stmt)
-{
-   free_expr(stmt->expr);
-}
-
-static void
-free_return_next(PLpgSQL_stmt_return_next *stmt)
-{
-   free_expr(stmt->expr);
-}
-
-static void
-free_return_query(PLpgSQL_stmt_return_query *stmt)
-{
-   ListCell   *lc;
-
-   free_expr(stmt->query);
-   free_expr(stmt->dynquery);
-   foreach(lc, stmt->params)
-   {
-       free_expr((PLpgSQL_expr *) lfirst(lc));
-   }
-}
-
-static void
-free_raise(PLpgSQL_stmt_raise *stmt)
-{
-   ListCell   *lc;
-
-   foreach(lc, stmt->params)
-   {
-       free_expr((PLpgSQL_expr *) lfirst(lc));
-   }
-   foreach(lc, stmt->options)
-   {
-       PLpgSQL_raise_option *opt = (PLpgSQL_raise_option *) lfirst(lc);
-
-       free_expr(opt->expr);
-   }
-}
-
-static void
-free_assert(PLpgSQL_stmt_assert *stmt)
-{
-   free_expr(stmt->cond);
-   free_expr(stmt->message);
-}
-
-static void
-free_execsql(PLpgSQL_stmt_execsql *stmt)
-{
-   free_expr(stmt->sqlstmt);
-}
-
-static void
-free_dynexecute(PLpgSQL_stmt_dynexecute *stmt)
-{
-   ListCell   *lc;
-
-   free_expr(stmt->query);
-   foreach(lc, stmt->params)
-   {
-       free_expr((PLpgSQL_expr *) lfirst(lc));
-   }
-}
-
-static void
-free_dynfors(PLpgSQL_stmt_dynfors *stmt)
-{
-   ListCell   *lc;
-
-   free_stmts(stmt->body);
-   free_expr(stmt->query);
-   foreach(lc, stmt->params)
-   {
-       free_expr((PLpgSQL_expr *) lfirst(lc));
-   }
-}
+/**********************************************************************
+ * Release memory when a PL/pgSQL function is no longer needed
+ *
+ * This code only needs to deal with cleaning up PLpgSQL_expr nodes,
+ * which may contain references to saved SPI Plans that must be freed.
+ * The function tree itself, along with subsidiary data, is freed in
+ * one swoop by freeing the function's permanent memory context.
+ **********************************************************************/
+static void free_stmt(PLpgSQL_stmt *stmt, void *context);
+static void free_expr(PLpgSQL_expr *expr, void *context);
 
 static void
-free_getdiag(PLpgSQL_stmt_getdiag *stmt)
+free_stmt(PLpgSQL_stmt *stmt, void *context)
 {
+   if (stmt == NULL)
+       return;
+   plpgsql_statement_tree_walker(stmt, free_stmt, free_expr, NULL);
 }
 
 static void
-free_expr(PLpgSQL_expr *expr)
+free_expr(PLpgSQL_expr *expr, void *context)
 {
    if (expr && expr->plan)
    {
@@ -743,8 +647,8 @@ plpgsql_free_function_memory(PLpgSQL_function *func)
                {
                    PLpgSQL_var *var = (PLpgSQL_var *) d;
 
-                   free_expr(var->default_val);
-                   free_expr(var->cursor_explicit_expr);
+                   free_expr(var->default_val, NULL);
+                   free_expr(var->cursor_explicit_expr, NULL);
                }
                break;
            case PLPGSQL_DTYPE_ROW:
@@ -753,7 +657,7 @@ plpgsql_free_function_memory(PLpgSQL_function *func)
                {
                    PLpgSQL_rec *rec = (PLpgSQL_rec *) d;
 
-                   free_expr(rec->default_val);
+                   free_expr(rec->default_val, NULL);
                }
                break;
            case PLPGSQL_DTYPE_RECFIELD:
@@ -765,8 +669,7 @@ plpgsql_free_function_memory(PLpgSQL_function *func)
    func->ndatums = 0;
 
    /* Release plans in statement tree */
-   if (func->action)
-       free_block(func->action);
+   free_stmt((PLpgSQL_stmt *) func->action, NULL);
    func->action = NULL;
 
    /*
@@ -782,6 +685,9 @@ plpgsql_free_function_memory(PLpgSQL_function *func)
 
 /**********************************************************************
  * Debug functions for analyzing the compiled code
+ *
+ * Sadly, there doesn't seem to be any way to let plpgsql_statement_tree_walker
+ * bear some of the burden for this.
  **********************************************************************/
 static int dump_indent;