Consolidate and improve checking of key-column-attnum arguments for
authorTom Lane <tgl@sss.pgh.pa.us>
Tue, 15 Jun 2010 16:22:39 +0000 (16:22 +0000)
committerTom Lane <tgl@sss.pgh.pa.us>
Tue, 15 Jun 2010 16:22:39 +0000 (16:22 +0000)
dblink_build_sql_insert() and related functions.  In particular, be sure to
reject references to dropped and out-of-range column numbers.  The numbers
are still interpreted as physical column numbers, though, for backward
compatibility.

This patch replaces Joe's patch of 2010-02-03, which handled only some aspects
of the problem.

contrib/dblink/dblink.c
contrib/dblink/expected/dblink.out

index 5c3b59105fe8db9e1a8a15c73af4915d7af93ea2..b03b16a139a51ef7001ed0bc445cb2500e35c3ec 100644 (file)
@@ -8,7 +8,7 @@
  * Darko Prenosil <Darko.Prenosil@finteh.hr>
  * Shridhar Daithankar <shridhar_daithankar@persistent.co.in>
  *
- * $PostgreSQL: pgsql/contrib/dblink/dblink.c,v 1.60.2.8 2010/06/14 20:49:51 tgl Exp $
+ * $PostgreSQL: pgsql/contrib/dblink/dblink.c,v 1.60.2.9 2010/06/15 16:22:39 tgl Exp $
  * Copyright (c) 2001-2006, PostgreSQL Global Development Group
  * ALL RIGHTS RESERVED;
  *
@@ -83,18 +83,20 @@ static void createNewConnection(const char *name, remoteConn * rconn);
 static void deleteConnection(const char *name);
 static char **get_pkey_attnames(Relation rel, int16 *numatts);
 static char **get_text_array_contents(ArrayType *array, int *numitems);
-static char *get_sql_insert(Relation rel, int2vector *pkattnums, int16 pknumatts, char **src_pkattvals, char **tgt_pkattvals);
-static char *get_sql_delete(Relation rel, int2vector *pkattnums, int16 pknumatts, char **tgt_pkattvals);
-static char *get_sql_update(Relation rel, int2vector *pkattnums, int16 pknumatts, char **src_pkattvals, char **tgt_pkattvals);
+static char *get_sql_insert(Relation rel, int *pkattnums, int pknumatts, char **src_pkattvals, char **tgt_pkattvals);
+static char *get_sql_delete(Relation rel, int *pkattnums, int pknumatts, char **tgt_pkattvals);
+static char *get_sql_update(Relation rel, int *pkattnums, int pknumatts, char **src_pkattvals, char **tgt_pkattvals);
 static char *quote_literal_cstr(char *rawstr);
 static char *quote_ident_cstr(char *rawstr);
-static int16 get_attnum_pk_pos(int2vector *pkattnums, int16 pknumatts, int16 key);
-static HeapTuple get_tuple_of_interest(Relation rel, int2vector *pkattnums, int16 pknumatts, char **src_pkattvals);
+static int get_attnum_pk_pos(int *pkattnums, int pknumatts, int key);
+static HeapTuple get_tuple_of_interest(Relation rel, int *pkattnums, int pknumatts, char **src_pkattvals);
 static Relation get_rel_from_relname(text *relname_text, LOCKMODE lockmode, AclMode aclmode);
 static char *generate_relation_name(Relation rel);
 static char *connstr_strip_password(const char *connstr);
 static void dblink_security_check(PGconn *conn, remoteConn *rconn, const char *connstr);
-static int get_nondropped_natts(Relation rel);
+static void validate_pkattnums(Relation rel,
+                  int2vector *pkattnums_arg, int32 pknumatts_arg,
+                  int **pkattnums, int *pknumatts);
 
 /* Global */
 static remoteConn *pconn = NULL;
@@ -1377,18 +1379,18 @@ Datum
 dblink_build_sql_insert(PG_FUNCTION_ARGS)
 {
    text       *relname_text = PG_GETARG_TEXT_P(0);
-   int2vector *pkattnums = (int2vector *) PG_GETARG_POINTER(1);
-   int32       pknumatts_tmp = PG_GETARG_INT32(2);
+   int2vector *pkattnums_arg = (int2vector *) PG_GETARG_POINTER(1);
+   int32       pknumatts_arg = PG_GETARG_INT32(2);
    ArrayType  *src_pkattvals_arry = PG_GETARG_ARRAYTYPE_P(3);
    ArrayType  *tgt_pkattvals_arry = PG_GETARG_ARRAYTYPE_P(4);
    Relation    rel;
-   int16       pknumatts = 0;
+   int        *pkattnums;
+   int         pknumatts;
    char      **src_pkattvals;
    char      **tgt_pkattvals;
    int         src_nitems;
    int         tgt_nitems;
    char       *sql;
-   int         nondropped_natts;
 
    /*
     * Open target relation.
@@ -1396,29 +1398,10 @@ dblink_build_sql_insert(PG_FUNCTION_ARGS)
    rel = get_rel_from_relname(relname_text, AccessShareLock, ACL_SELECT);
 
    /*
-    * There should be at least one key attribute
+    * Process pkattnums argument.
     */
-   if (pknumatts_tmp <= 0)
-       ereport(ERROR,
-               (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
-                errmsg("number of key attributes must be > 0")));
-
-   if (pknumatts_tmp <= SHRT_MAX)
-       pknumatts = pknumatts_tmp;
-   else
-       ereport(ERROR,
-               (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
-                errmsg("input for number of primary key " \
-                       "attributes too large")));
-
-   /*
-    * ensure we don't ask for more pk attributes than we have
-    * non-dropped columns
-    */
-   nondropped_natts = get_nondropped_natts(rel);
-   if (pknumatts > nondropped_natts)
-       ereport(ERROR, (errcode(ERRCODE_SYNTAX_ERROR),
-               errmsg("number of primary key fields exceeds number of specified relation attributes")));
+   validate_pkattnums(rel, pkattnums_arg, pknumatts_arg,
+                      &pkattnums, &pknumatts);
 
    /*
     * Source array is made up of key values that will be used to locate the
@@ -1487,12 +1470,12 @@ Datum
 dblink_build_sql_delete(PG_FUNCTION_ARGS)
 {
    text       *relname_text = PG_GETARG_TEXT_P(0);
-   int2vector *pkattnums = (int2vector *) PG_GETARG_POINTER(1);
-   int32       pknumatts_tmp = PG_GETARG_INT32(2);
+   int2vector *pkattnums_arg = (int2vector *) PG_GETARG_POINTER(1);
+   int32       pknumatts_arg = PG_GETARG_INT32(2);
    ArrayType  *tgt_pkattvals_arry = PG_GETARG_ARRAYTYPE_P(3);
-   int         nondropped_natts;
    Relation    rel;
-   int16       pknumatts = 0;
+   int        *pkattnums;
+   int         pknumatts;
    char      **tgt_pkattvals;
    int         tgt_nitems;
    char       *sql;
@@ -1503,29 +1486,10 @@ dblink_build_sql_delete(PG_FUNCTION_ARGS)
    rel = get_rel_from_relname(relname_text, AccessShareLock, ACL_SELECT);
 
    /*
-    * There should be at least one key attribute
-    */
-   if (pknumatts_tmp <= 0)
-       ereport(ERROR,
-               (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
-                errmsg("number of key attributes must be > 0")));
-
-   if (pknumatts_tmp <= SHRT_MAX)
-       pknumatts = pknumatts_tmp;
-   else
-       ereport(ERROR,
-               (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
-                errmsg("input for number of primary key " \
-                       "attributes too large")));
-
-   /*
-    * ensure we don't ask for more pk attributes than we have
-    * non-dropped columns
+    * Process pkattnums argument.
     */
-   nondropped_natts = get_nondropped_natts(rel);
-   if (pknumatts > nondropped_natts)
-       ereport(ERROR, (errcode(ERRCODE_SYNTAX_ERROR),
-               errmsg("number of primary key fields exceeds number of specified relation attributes")));
+   validate_pkattnums(rel, pkattnums_arg, pknumatts_arg,
+                      &pkattnums, &pknumatts);
 
    /*
     * Target array is made up of key values that will be used to build the
@@ -1583,13 +1547,13 @@ Datum
 dblink_build_sql_update(PG_FUNCTION_ARGS)
 {
    text       *relname_text = PG_GETARG_TEXT_P(0);
-   int2vector *pkattnums = (int2vector *) PG_GETARG_POINTER(1);
-   int32       pknumatts_tmp = PG_GETARG_INT32(2);
+   int2vector *pkattnums_arg = (int2vector *) PG_GETARG_POINTER(1);
+   int32       pknumatts_arg = PG_GETARG_INT32(2);
    ArrayType  *src_pkattvals_arry = PG_GETARG_ARRAYTYPE_P(3);
    ArrayType  *tgt_pkattvals_arry = PG_GETARG_ARRAYTYPE_P(4);
-   int         nondropped_natts;
    Relation    rel;
-   int16       pknumatts = 0;
+   int        *pkattnums;
+   int         pknumatts;
    char      **src_pkattvals;
    char      **tgt_pkattvals;
    int         src_nitems;
@@ -1602,29 +1566,10 @@ dblink_build_sql_update(PG_FUNCTION_ARGS)
    rel = get_rel_from_relname(relname_text, AccessShareLock, ACL_SELECT);
 
    /*
-    * There should be one source array key values for each key attnum
+    * Process pkattnums argument.
     */
-   if (pknumatts_tmp <= 0)
-       ereport(ERROR,
-               (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
-                errmsg("number of key attributes must be > 0")));
-
-   if (pknumatts_tmp <= SHRT_MAX)
-       pknumatts = pknumatts_tmp;
-   else
-       ereport(ERROR,
-               (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
-                errmsg("input for number of primary key " \
-                       "attributes too large")));
-
-   /*
-    * ensure we don't ask for more pk attributes than we have
-    * non-dropped columns
-    */
-   nondropped_natts = get_nondropped_natts(rel);
-   if (pknumatts > nondropped_natts)
-       ereport(ERROR, (errcode(ERRCODE_SYNTAX_ERROR),
-               errmsg("number of primary key fields exceeds number of specified relation attributes")));
+   validate_pkattnums(rel, pkattnums_arg, pknumatts_arg,
+                      &pkattnums, &pknumatts);
 
    /*
     * Source array is made up of key values that will be used to locate the
@@ -1810,7 +1755,7 @@ get_text_array_contents(ArrayType *array, int *numitems)
 }
 
 static char *
-get_sql_insert(Relation rel, int2vector *pkattnums, int16 pknumatts, char **src_pkattvals, char **tgt_pkattvals)
+get_sql_insert(Relation rel, int *pkattnums, int pknumatts, char **src_pkattvals, char **tgt_pkattvals)
 {
    char       *relname;
    HeapTuple   tuple;
@@ -1818,7 +1763,7 @@ get_sql_insert(Relation rel, int2vector *pkattnums, int16 pknumatts, char **src_
    int         natts;
    StringInfoData buf;
    char       *val;
-   int16       key;
+   int         key;
    int         i;
    bool        needComma;
 
@@ -1867,7 +1812,7 @@ get_sql_insert(Relation rel, int2vector *pkattnums, int16 pknumatts, char **src_
            appendStringInfo(&buf, ",");
 
        if (tgt_pkattvals != NULL)
-           key = get_attnum_pk_pos(pkattnums, pknumatts, i + 1);
+           key = get_attnum_pk_pos(pkattnums, pknumatts, i);
        else
            key = -1;
 
@@ -1891,11 +1836,10 @@ get_sql_insert(Relation rel, int2vector *pkattnums, int16 pknumatts, char **src_
 }
 
 static char *
-get_sql_delete(Relation rel, int2vector *pkattnums, int16 pknumatts, char **tgt_pkattvals)
+get_sql_delete(Relation rel, int *pkattnums, int pknumatts, char **tgt_pkattvals)
 {
    char       *relname;
    TupleDesc   tupdesc;
-   int         natts;
    StringInfoData buf;
    int         i;
 
@@ -1905,18 +1849,17 @@ get_sql_delete(Relation rel, int2vector *pkattnums, int16 pknumatts, char **tgt_
    relname = generate_relation_name(rel);
 
    tupdesc = rel->rd_att;
-   natts = tupdesc->natts;
 
    appendStringInfo(&buf, "DELETE FROM %s WHERE ", relname);
    for (i = 0; i < pknumatts; i++)
    {
-       int16       pkattnum = pkattnums->values[i];
+       int         pkattnum = pkattnums[i];
 
        if (i > 0)
            appendStringInfo(&buf, " AND ");
 
        appendStringInfoString(&buf,
-          quote_ident_cstr(NameStr(tupdesc->attrs[pkattnum - 1]->attname)));
+          quote_ident_cstr(NameStr(tupdesc->attrs[pkattnum]->attname)));
 
        if (tgt_pkattvals == NULL)
            /* internal error */
@@ -1933,7 +1876,7 @@ get_sql_delete(Relation rel, int2vector *pkattnums, int16 pknumatts, char **tgt_
 }
 
 static char *
-get_sql_update(Relation rel, int2vector *pkattnums, int16 pknumatts, char **src_pkattvals, char **tgt_pkattvals)
+get_sql_update(Relation rel, int *pkattnums, int pknumatts, char **src_pkattvals, char **tgt_pkattvals)
 {
    char       *relname;
    HeapTuple   tuple;
@@ -1941,7 +1884,7 @@ get_sql_update(Relation rel, int2vector *pkattnums, int16 pknumatts, char **src_
    int         natts;
    StringInfoData buf;
    char       *val;
-   int16       key;
+   int         key;
    int         i;
    bool        needComma;
 
@@ -1974,7 +1917,7 @@ get_sql_update(Relation rel, int2vector *pkattnums, int16 pknumatts, char **src_
                      quote_ident_cstr(NameStr(tupdesc->attrs[i]->attname)));
 
        if (tgt_pkattvals != NULL)
-           key = get_attnum_pk_pos(pkattnums, pknumatts, i + 1);
+           key = get_attnum_pk_pos(pkattnums, pknumatts, i);
        else
            key = -1;
 
@@ -1997,18 +1940,18 @@ get_sql_update(Relation rel, int2vector *pkattnums, int16 pknumatts, char **src_
 
    for (i = 0; i < pknumatts; i++)
    {
-       int16       pkattnum = pkattnums->values[i];
+       int         pkattnum = pkattnums[i];
 
        if (i > 0)
            appendStringInfo(&buf, " AND ");
 
        appendStringInfo(&buf, "%s",
-          quote_ident_cstr(NameStr(tupdesc->attrs[pkattnum - 1]->attname)));
+          quote_ident_cstr(NameStr(tupdesc->attrs[pkattnum]->attname)));
 
        if (tgt_pkattvals != NULL)
            val = tgt_pkattvals[i] ? pstrdup(tgt_pkattvals[i]) : NULL;
        else
-           val = SPI_getvalue(tuple, tupdesc, pkattnum);
+           val = SPI_getvalue(tuple, tupdesc, pkattnum + 1);
 
        if (val != NULL)
        {
@@ -2058,8 +2001,8 @@ quote_ident_cstr(char *rawstr)
    return result;
 }
 
-static int16
-get_attnum_pk_pos(int2vector *pkattnums, int16 pknumatts, int16 key)
+static int
+get_attnum_pk_pos(int *pkattnums, int pknumatts, int key)
 {
    int         i;
 
@@ -2067,14 +2010,14 @@ get_attnum_pk_pos(int2vector *pkattnums, int16 pknumatts, int16 key)
     * Not likely a long list anyway, so just scan for the value
     */
    for (i = 0; i < pknumatts; i++)
-       if (key == pkattnums->values[i])
+       if (key == pkattnums[i])
            return i;
 
    return -1;
 }
 
 static HeapTuple
-get_tuple_of_interest(Relation rel, int2vector *pkattnums, int16 pknumatts, char **src_pkattvals)
+get_tuple_of_interest(Relation rel, int *pkattnums, int pknumatts, char **src_pkattvals)
 {
    char       *relname;
    TupleDesc   tupdesc;
@@ -2105,13 +2048,13 @@ get_tuple_of_interest(Relation rel, int2vector *pkattnums, int16 pknumatts, char
 
    for (i = 0; i < pknumatts; i++)
    {
-       int16       pkattnum = pkattnums->values[i];
+       int         pkattnum = pkattnums[i];
 
        if (i > 0)
            appendStringInfo(&buf, " AND ");
 
        appendStringInfoString(&buf,
-          quote_ident_cstr(NameStr(tupdesc->attrs[pkattnum - 1]->attname)));
+          quote_ident_cstr(NameStr(tupdesc->attrs[pkattnum]->attname)));
 
        if (src_pkattvals[i] != NULL)
            appendStringInfo(&buf, " = %s",
@@ -2445,24 +2388,52 @@ dblink_security_check(PGconn *conn, remoteConn *rconn, const char *connstr)
    }
 }
 
-static int
-get_nondropped_natts(Relation rel)
+/*
+ * Validate the PK-attnums argument for dblink_build_sql_insert() and related
+ * functions, and translate to the internal representation.
+ *
+ * The user supplies an int2vector of 1-based physical attnums, plus a count
+ * argument (the need for the separate count argument is historical, but we
+ * still check it).  We check that each attnum corresponds to a valid,
+ * non-dropped attribute of the rel.  We do *not* prevent attnums from being
+ * listed twice, though the actual use-case for such things is dubious.
+ *
+ * The internal representation is a palloc'd int array of 0-based physical
+ * attnums.
+ */
+static void
+validate_pkattnums(Relation rel,
+                  int2vector *pkattnums_arg, int32 pknumatts_arg,
+                  int **pkattnums, int *pknumatts)
 {
-   int         nondropped_natts = 0;
-   TupleDesc   tupdesc;
-   int         natts;
+   TupleDesc   tupdesc = rel->rd_att;
+   int         natts = tupdesc->natts;
    int         i;
 
-   tupdesc = rel->rd_att;
-   natts = tupdesc->natts;
+   /* Don't take more array elements than there are */
+   pknumatts_arg = Min(pknumatts_arg, pkattnums_arg->dim1);
 
-   for (i = 0; i < natts; i++)
+   /* Must have at least one pk attnum selected */
+   if (pknumatts_arg <= 0)
+       ereport(ERROR,
+               (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+                errmsg("number of key attributes must be > 0")));
+
+   /* Allocate output array */
+   *pkattnums = (int *) palloc(pknumatts_arg * sizeof(int));
+   *pknumatts = pknumatts_arg;
+
+   /* Validate attnums and convert to internal form */
+   for (i = 0; i < pknumatts_arg; i++)
    {
-       if (tupdesc->attrs[i]->attisdropped)
-           continue;
-       nondropped_natts++;
-   }
+       int     pkattnum = pkattnums_arg->values[i];
 
-   return nondropped_natts;
+       if (pkattnum <= 0 || pkattnum > natts ||
+           tupdesc->attrs[pkattnum - 1]->attisdropped)
+           ereport(ERROR,
+                   (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+                    errmsg("invalid attribute number %d", pkattnum)));
+       (*pkattnums)[i] = pkattnum - 1;
+   }
 }
 
index 6a08581ec8fb65c4a8d144b163a8f99cda90e662..4804d8f297de3e1346abfa1d262f8acf0adeca1c 100644 (file)
@@ -46,7 +46,7 @@ SELECT dblink_build_sql_insert('foo','1 2',2,'{"0", "a"}','{"99", "xyz"}');
 
 -- too many pk fields, should fail
 SELECT dblink_build_sql_insert('foo','1 2 3 4',4,'{"0", "a", "{a0,b0,c0}"}','{"99", "xyz", "{za0,zb0,zc0}"}');
-ERROR:  number of primary key fields exceeds number of specified relation attributes
+ERROR:  invalid attribute number 4
 -- build an update statement based on a local tuple,
 -- replacing the primary key values with new ones
 SELECT dblink_build_sql_update('foo','1 2',2,'{"0", "a"}','{"99", "xyz"}');
@@ -57,7 +57,7 @@ SELECT dblink_build_sql_update('foo','1 2',2,'{"0", "a"}','{"99", "xyz"}');
 
 -- too many pk fields, should fail
 SELECT dblink_build_sql_update('foo','1 2 3 4',4,'{"0", "a", "{a0,b0,c0}"}','{"99", "xyz", "{za0,zb0,zc0}"}');
-ERROR:  number of primary key fields exceeds number of specified relation attributes
+ERROR:  invalid attribute number 4
 -- build a delete statement based on a local tuple,
 SELECT dblink_build_sql_delete('foo','1 2',2,'{"0", "a"}');
            dblink_build_sql_delete           
@@ -67,7 +67,7 @@ SELECT dblink_build_sql_delete('foo','1 2',2,'{"0", "a"}');
 
 -- too many pk fields, should fail
 SELECT dblink_build_sql_delete('foo','1 2 3 4',4,'{"0", "a", "{a0,b0,c0}"}');
-ERROR:  number of primary key fields exceeds number of specified relation attributes
+ERROR:  invalid attribute number 4
 -- retest using a quoted and schema qualified table
 CREATE SCHEMA "MySchema";
 CREATE TABLE "MySchema"."Foo"(f1 int, f2 text, f3 text[], primary key (f1,f2));
@@ -726,7 +726,7 @@ UNION
 (SELECT * from dblink_get_result('dtest3') as t3(f1 int, f2 text, f3 text[]))
 ORDER by f1;
 SELECT dblink_get_connections();
- dblink_get_connections
+ dblink_get_connections 
 ------------------------
  {dtest1,dtest2,dtest3}
 (1 row)