Improve error handling of HMAC computations
authorMichael Paquier <michael@paquier.xyz>
Thu, 13 Jan 2022 07:17:21 +0000 (16:17 +0900)
committerMichael Paquier <michael@paquier.xyz>
Thu, 13 Jan 2022 07:17:21 +0000 (16:17 +0900)
This is similar to b69aba7, except that this completes the work for
HMAC with a new routine called pg_hmac_error() that would provide more
context about the type of error that happened during a HMAC computation:
- The fallback HMAC implementation in hmac.c relies on cryptohashes, so
in some code paths it is necessary to return back the error generated by
cryptohashes.
- For the OpenSSL implementation (hmac_openssl.c), the logic is very
similar to cryptohash_openssl.c, where the error context comes from
OpenSSL if one of its internal routines failed, with different error
codes if something internal to hmac_openssl.c failed or was incorrect.

Any in-core code paths that use the centralized HMAC interface are
related to SCRAM, for errors that are unlikely going to happen, with
only SHA-256.  It would be possible to see errors when computing some
HMACs with MD5 for example and OpenSSL FIPS enabled, and this commit
would help in reporting the correct errors but nothing in core uses
that.  So, at the end, no backpatch to v14 is done, at least for now.

Errors in SCRAM related to the computation of the server key, stored
key, etc. need to pass down the potential error context string across
more layers of their respective call stacks for the frontend and the
backend, so each surrounding routine is adapted for this purpose.

Reviewed-by: Sergey Shinderuk
Discussion: https://postgr.es/m/Yd0N9tSAIIkFd+qi@paquier.xyz

src/backend/libpq/auth-scram.c
src/common/hmac.c
src/common/hmac_openssl.c
src/common/scram-common.c
src/include/common/hmac.h
src/include/common/scram-common.h
src/interfaces/libpq/fe-auth-scram.c
src/interfaces/libpq/fe-auth.c
src/interfaces/libpq/fe-auth.h
src/tools/pgindent/typedefs.list

index 7c9dee70ce71ebd31972ebb86677ec0709a98888..ee7f52218ab33bc23184b77d2d52e822309271d4 100644 (file)
@@ -465,6 +465,7 @@ pg_be_scram_build_secret(const char *password)
    pg_saslprep_rc rc;
    char        saltbuf[SCRAM_DEFAULT_SALT_LEN];
    char       *result;
+   const char *errstr = NULL;
 
    /*
     * Normalize the password with SASLprep.  If that doesn't work, because
@@ -482,7 +483,8 @@ pg_be_scram_build_secret(const char *password)
                 errmsg("could not generate random salt")));
 
    result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
-                               SCRAM_DEFAULT_ITERATIONS, password);
+                               SCRAM_DEFAULT_ITERATIONS, password,
+                               &errstr);
 
    if (prep_password)
        pfree(prep_password);
@@ -509,6 +511,7 @@ scram_verify_plain_password(const char *username, const char *password,
    uint8       computed_key[SCRAM_KEY_LEN];
    char       *prep_password;
    pg_saslprep_rc rc;
+   const char *errstr = NULL;
 
    if (!parse_scram_secret(secret, &iterations, &encoded_salt,
                            stored_key, server_key))
@@ -539,10 +542,10 @@ scram_verify_plain_password(const char *username, const char *password,
 
    /* Compute Server Key based on the user-supplied plaintext password */
    if (scram_SaltedPassword(password, salt, saltlen, iterations,
-                            salted_password) < 0 ||
-       scram_ServerKey(salted_password, computed_key) < 0)
+                            salted_password, &errstr) < 0 ||
+       scram_ServerKey(salted_password, computed_key, &errstr) < 0)
    {
-       elog(ERROR, "could not compute server key");
+       elog(ERROR, "could not compute server key: %s", errstr);
    }
 
    if (prep_password)
@@ -1113,6 +1116,7 @@ verify_client_proof(scram_state *state)
    uint8       client_StoredKey[SCRAM_KEY_LEN];
    pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
    int         i;
+   const char *errstr = NULL;
 
    /*
     * Calculate ClientSignature.  Note that we don't log directly a failure
@@ -1133,7 +1137,8 @@ verify_client_proof(scram_state *state)
                       strlen(state->client_final_message_without_proof)) < 0 ||
        pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0)
    {
-       elog(ERROR, "could not calculate client signature");
+       elog(ERROR, "could not calculate client signature: %s",
+            pg_hmac_error(ctx));
    }
 
    pg_hmac_free(ctx);
@@ -1143,8 +1148,8 @@ verify_client_proof(scram_state *state)
        ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
 
    /* Hash it one more time, and compare with StoredKey */
-   if (scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey) < 0)
-       elog(ERROR, "could not hash stored key");
+   if (scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey, &errstr) < 0)
+       elog(ERROR, "could not hash stored key: %s", errstr);
 
    if (memcmp(client_StoredKey, state->StoredKey, SCRAM_KEY_LEN) != 0)
        return false;
@@ -1389,7 +1394,8 @@ build_server_final_message(scram_state *state)
                       strlen(state->client_final_message_without_proof)) < 0 ||
        pg_hmac_final(ctx, ServerSignature, sizeof(ServerSignature)) < 0)
    {
-       elog(ERROR, "could not calculate server signature");
+       elog(ERROR, "could not calculate server signature: %s",
+            pg_hmac_error(ctx));
    }
 
    pg_hmac_free(ctx);
index d40026d3e99fce83b1363bf940d9496a5dd223d3..a27778e86b34912ee61a419e0ff25177bc82c87d 100644 (file)
 #define FREE(ptr) free(ptr)
 #endif
 
+/* Set of error states */
+typedef enum pg_hmac_errno
+{
+   PG_HMAC_ERROR_NONE = 0,
+   PG_HMAC_ERROR_OOM,
+   PG_HMAC_ERROR_INTERNAL
+} pg_hmac_errno;
+
 /* Internal pg_hmac_ctx structure */
 struct pg_hmac_ctx
 {
    pg_cryptohash_ctx *hash;
    pg_cryptohash_type type;
+   pg_hmac_errno error;
+   const char *errreason;
    int         block_size;
    int         digest_size;
 
@@ -73,6 +83,8 @@ pg_hmac_create(pg_cryptohash_type type)
        return NULL;
    memset(ctx, 0, sizeof(pg_hmac_ctx));
    ctx->type = type;
+   ctx->error = PG_HMAC_ERROR_NONE;
+   ctx->errreason = NULL;
 
    /*
     * Initialize the context data.  This requires to know the digest and
@@ -150,12 +162,16 @@ pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len)
        /* temporary buffer for one-time shrink */
        shrinkbuf = ALLOC(digest_size);
        if (shrinkbuf == NULL)
+       {
+           ctx->error = PG_HMAC_ERROR_OOM;
            return -1;
+       }
        memset(shrinkbuf, 0, digest_size);
 
        hash_ctx = pg_cryptohash_create(ctx->type);
        if (hash_ctx == NULL)
        {
+           ctx->error = PG_HMAC_ERROR_OOM;
            FREE(shrinkbuf);
            return -1;
        }
@@ -164,6 +180,8 @@ pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len)
            pg_cryptohash_update(hash_ctx, key, len) < 0 ||
            pg_cryptohash_final(hash_ctx, shrinkbuf, digest_size) < 0)
        {
+           ctx->error = PG_HMAC_ERROR_INTERNAL;
+           ctx->errreason = pg_cryptohash_error(hash_ctx);
            pg_cryptohash_free(hash_ctx);
            FREE(shrinkbuf);
            return -1;
@@ -184,6 +202,8 @@ pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len)
    if (pg_cryptohash_init(ctx->hash) < 0 ||
        pg_cryptohash_update(ctx->hash, ctx->k_ipad, ctx->block_size) < 0)
    {
+       ctx->error = PG_HMAC_ERROR_INTERNAL;
+       ctx->errreason = pg_cryptohash_error(ctx->hash);
        if (shrinkbuf)
            FREE(shrinkbuf);
        return -1;
@@ -206,7 +226,11 @@ pg_hmac_update(pg_hmac_ctx *ctx, const uint8 *data, size_t len)
        return -1;
 
    if (pg_cryptohash_update(ctx->hash, data, len) < 0)
+   {
+       ctx->error = PG_HMAC_ERROR_INTERNAL;
+       ctx->errreason = pg_cryptohash_error(ctx->hash);
        return -1;
+   }
 
    return 0;
 }
@@ -226,11 +250,16 @@ pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len)
 
    h = ALLOC(ctx->digest_size);
    if (h == NULL)
+   {
+       ctx->error = PG_HMAC_ERROR_OOM;
        return -1;
+   }
    memset(h, 0, ctx->digest_size);
 
    if (pg_cryptohash_final(ctx->hash, h, ctx->digest_size) < 0)
    {
+       ctx->error = PG_HMAC_ERROR_INTERNAL;
+       ctx->errreason = pg_cryptohash_error(ctx->hash);
        FREE(h);
        return -1;
    }
@@ -241,6 +270,8 @@ pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len)
        pg_cryptohash_update(ctx->hash, h, ctx->digest_size) < 0 ||
        pg_cryptohash_final(ctx->hash, dest, len) < 0)
    {
+       ctx->error = PG_HMAC_ERROR_INTERNAL;
+       ctx->errreason = pg_cryptohash_error(ctx->hash);
        FREE(h);
        return -1;
    }
@@ -264,3 +295,36 @@ pg_hmac_free(pg_hmac_ctx *ctx)
    explicit_bzero(ctx, sizeof(pg_hmac_ctx));
    FREE(ctx);
 }
+
+/*
+ * pg_hmac_error
+ *
+ * Returns a static string providing details about an error that happened
+ * during a HMAC computation.
+ */
+const char *
+pg_hmac_error(pg_hmac_ctx *ctx)
+{
+   if (ctx == NULL)
+       return _("out of memory");
+
+   /*
+    * If a reason is provided, rely on it, else fallback to any error code
+    * set.
+    */
+   if (ctx->errreason)
+       return ctx->errreason;
+
+   switch (ctx->error)
+   {
+       case PG_HMAC_ERROR_NONE:
+           return _("success");
+       case PG_HMAC_ERROR_INTERNAL:
+           return _("internal error");
+       case PG_HMAC_ERROR_OOM:
+           return _("out of memory");
+   }
+
+   Assert(false);              /* cannot be reached */
+   return _("success");
+}
index 7efa90c99e6107709e99659e3ce443b90dddc320..44f36d51dcb0fe0d59a90336884e26ce9fda2bea 100644 (file)
@@ -20,6 +20,8 @@
 #include "postgres_fe.h"
 #endif
 
+
+#include <openssl/err.h>
 #include <openssl/hmac.h>
 
 #include "common/hmac.h"
 #define FREE(ptr) free(ptr)
 #endif                         /* FRONTEND */
 
+/* Set of error states */
+typedef enum pg_hmac_errno
+{
+   PG_HMAC_ERROR_NONE = 0,
+   PG_HMAC_ERROR_DEST_LEN,
+   PG_HMAC_ERROR_OPENSSL
+} pg_hmac_errno;
+
 /* Internal pg_hmac_ctx structure */
 struct pg_hmac_ctx
 {
    HMAC_CTX   *hmacctx;
    pg_cryptohash_type type;
+   pg_hmac_errno error;
+   const char *errreason;
 
 #ifndef FRONTEND
    ResourceOwner resowner;
 #endif
 };
 
+static const char *
+SSLerrmessage(unsigned long ecode)
+{
+   if (ecode == 0)
+       return NULL;
+
+   /*
+    * This may return NULL, but we would fall back to a default error path if
+    * that were the case.
+    */
+   return ERR_reason_error_string(ecode);
+}
+
 /*
  * pg_hmac_create
  *
@@ -78,6 +103,8 @@ pg_hmac_create(pg_cryptohash_type type)
    memset(ctx, 0, sizeof(pg_hmac_ctx));
 
    ctx->type = type;
+   ctx->error = PG_HMAC_ERROR_NONE;
+   ctx->errreason = NULL;
 
    /*
     * Initialization takes care of assigning the correct type for OpenSSL.
@@ -152,7 +179,11 @@ pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len)
 
    /* OpenSSL internals return 1 on success, 0 on failure */
    if (status <= 0)
+   {
+       ctx->errreason = SSLerrmessage(ERR_get_error());
+       ctx->error = PG_HMAC_ERROR_OPENSSL;
        return -1;
+   }
 
    return 0;
 }
@@ -174,7 +205,11 @@ pg_hmac_update(pg_hmac_ctx *ctx, const uint8 *data, size_t len)
 
    /* OpenSSL internals return 1 on success, 0 on failure */
    if (status <= 0)
+   {
+       ctx->errreason = SSLerrmessage(ERR_get_error());
+       ctx->error = PG_HMAC_ERROR_OPENSSL;
        return -1;
+   }
    return 0;
 }
 
@@ -196,27 +231,45 @@ pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len)
    {
        case PG_MD5:
            if (len < MD5_DIGEST_LENGTH)
+           {
+               ctx->error = PG_HMAC_ERROR_DEST_LEN;
                return -1;
+           }
            break;
        case PG_SHA1:
            if (len < SHA1_DIGEST_LENGTH)
+           {
+               ctx->error = PG_HMAC_ERROR_DEST_LEN;
                return -1;
+           }
            break;
        case PG_SHA224:
            if (len < PG_SHA224_DIGEST_LENGTH)
+           {
+               ctx->error = PG_HMAC_ERROR_DEST_LEN;
                return -1;
+           }
            break;
        case PG_SHA256:
            if (len < PG_SHA256_DIGEST_LENGTH)
+           {
+               ctx->error = PG_HMAC_ERROR_DEST_LEN;
                return -1;
+           }
            break;
        case PG_SHA384:
            if (len < PG_SHA384_DIGEST_LENGTH)
+           {
+               ctx->error = PG_HMAC_ERROR_DEST_LEN;
                return -1;
+           }
            break;
        case PG_SHA512:
            if (len < PG_SHA512_DIGEST_LENGTH)
+           {
+               ctx->error = PG_HMAC_ERROR_DEST_LEN;
                return -1;
+           }
            break;
    }
 
@@ -224,7 +277,11 @@ pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len)
 
    /* OpenSSL internals return 1 on success, 0 on failure */
    if (status <= 0)
+   {
+       ctx->errreason = SSLerrmessage(ERR_get_error());
+       ctx->error = PG_HMAC_ERROR_OPENSSL;
        return -1;
+   }
    return 0;
 }
 
@@ -252,3 +309,36 @@ pg_hmac_free(pg_hmac_ctx *ctx)
    explicit_bzero(ctx, sizeof(pg_hmac_ctx));
    FREE(ctx);
 }
+
+/*
+ * pg_hmac_error
+ *
+ * Returns a static string providing details about an error that happened
+ * during a HMAC computation.
+ */
+const char *
+pg_hmac_error(pg_hmac_ctx *ctx)
+{
+   if (ctx == NULL)
+       return _("out of memory");
+
+   /*
+    * If a reason is provided, rely on it, else fallback to any error code
+    * set.
+    */
+   if (ctx->errreason)
+       return ctx->errreason;
+
+   switch (ctx->error)
+   {
+       case PG_HMAC_ERROR_NONE:
+           return _("success");
+       case PG_HMAC_ERROR_DEST_LEN:
+           return _("destination buffer too small");
+       case PG_HMAC_ERROR_OPENSSL:
+           return _("OpenSSL failure");
+   }
+
+   Assert(false);              /* cannot be reached */
+   return _("success");
+}
index 23b68b14da79f3a7d669c56019893cef65f44c01..126862592999b509ada293dea13d32123792eb63 100644 (file)
  * Calculate SaltedPassword.
  *
  * The password should already be normalized by SASLprep.  Returns 0 on
- * success, -1 on failure.
+ * success, -1 on failure with *errstr pointing to a message about the
+ * error details.
  */
 int
 scram_SaltedPassword(const char *password,
                     const char *salt, int saltlen, int iterations,
-                    uint8 *result)
+                    uint8 *result, const char **errstr)
 {
    int         password_len = strlen(password);
    uint32      one = pg_hton32(1);
@@ -44,7 +45,10 @@ scram_SaltedPassword(const char *password,
    pg_hmac_ctx *hmac_ctx = pg_hmac_create(PG_SHA256);
 
    if (hmac_ctx == NULL)
+   {
+       *errstr = pg_hmac_error(NULL);  /* returns OOM */
        return -1;
+   }
 
    /*
     * Iterate hash calculation of HMAC entry using given salt.  This is
@@ -58,6 +62,7 @@ scram_SaltedPassword(const char *password,
        pg_hmac_update(hmac_ctx, (uint8 *) &one, sizeof(uint32)) < 0 ||
        pg_hmac_final(hmac_ctx, Ui_prev, sizeof(Ui_prev)) < 0)
    {
+       *errstr = pg_hmac_error(hmac_ctx);
        pg_hmac_free(hmac_ctx);
        return -1;
    }
@@ -71,6 +76,7 @@ scram_SaltedPassword(const char *password,
            pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, SCRAM_KEY_LEN) < 0 ||
            pg_hmac_final(hmac_ctx, Ui, sizeof(Ui)) < 0)
        {
+           *errstr = pg_hmac_error(hmac_ctx);
            pg_hmac_free(hmac_ctx);
            return -1;
        }
@@ -87,21 +93,26 @@ scram_SaltedPassword(const char *password,
 
 /*
  * Calculate SHA-256 hash for a NULL-terminated string. (The NULL terminator is
- * not included in the hash).  Returns 0 on success, -1 on failure.
+ * not included in the hash).  Returns 0 on success, -1 on failure with *errstr
+ * pointing to a message about the error details.
  */
 int
-scram_H(const uint8 *input, int len, uint8 *result)
+scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
 {
    pg_cryptohash_ctx *ctx;
 
    ctx = pg_cryptohash_create(PG_SHA256);
    if (ctx == NULL)
+   {
+       *errstr = pg_cryptohash_error(NULL);    /* returns OOM */
        return -1;
+   }
 
    if (pg_cryptohash_init(ctx) < 0 ||
        pg_cryptohash_update(ctx, input, len) < 0 ||
        pg_cryptohash_final(ctx, result, SCRAM_KEY_LEN) < 0)
    {
+       *errstr = pg_cryptohash_error(ctx);
        pg_cryptohash_free(ctx);
        return -1;
    }
@@ -111,20 +122,26 @@ scram_H(const uint8 *input, int len, uint8 *result)
 }
 
 /*
- * Calculate ClientKey.  Returns 0 on success, -1 on failure.
+ * Calculate ClientKey.  Returns 0 on success, -1 on failure with *errstr
+ * pointing to a message about the error details.
  */
 int
-scram_ClientKey(const uint8 *salted_password, uint8 *result)
+scram_ClientKey(const uint8 *salted_password, uint8 *result,
+               const char **errstr)
 {
    pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
 
    if (ctx == NULL)
+   {
+       *errstr = pg_hmac_error(NULL);  /* returns OOM */
        return -1;
+   }
 
    if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 ||
        pg_hmac_update(ctx, (uint8 *) "Client Key", strlen("Client Key")) < 0 ||
        pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0)
    {
+       *errstr = pg_hmac_error(ctx);
        pg_hmac_free(ctx);
        return -1;
    }
@@ -134,20 +151,26 @@ scram_ClientKey(const uint8 *salted_password, uint8 *result)
 }
 
 /*
- * Calculate ServerKey.  Returns 0 on success, -1 on failure.
+ * Calculate ServerKey.  Returns 0 on success, -1 on failure with *errstr
+ * pointing to a message about the error details.
  */
 int
-scram_ServerKey(const uint8 *salted_password, uint8 *result)
+scram_ServerKey(const uint8 *salted_password, uint8 *result,
+               const char **errstr)
 {
    pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
 
    if (ctx == NULL)
+   {
+       *errstr = pg_hmac_error(NULL);  /* returns OOM */
        return -1;
+   }
 
    if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 ||
        pg_hmac_update(ctx, (uint8 *) "Server Key", strlen("Server Key")) < 0 ||
        pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0)
    {
+       *errstr = pg_hmac_error(ctx);
        pg_hmac_free(ctx);
        return -1;
    }
@@ -164,10 +187,13 @@ scram_ServerKey(const uint8 *salted_password, uint8 *result)
  *
  * If iterations is 0, default number of iterations is used.  The result is
  * palloc'd or malloc'd, so caller is responsible for freeing it.
+ *
+ * On error, returns NULL and sets *errstr to point to a message about the
+ * error details.
  */
 char *
 scram_build_secret(const char *salt, int saltlen, int iterations,
-                  const char *password)
+                  const char *password, const char **errstr)
 {
    uint8       salted_password[SCRAM_KEY_LEN];
    uint8       stored_key[SCRAM_KEY_LEN];
@@ -185,15 +211,17 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 
    /* Calculate StoredKey and ServerKey */
    if (scram_SaltedPassword(password, salt, saltlen, iterations,
-                            salted_password) < 0 ||
-       scram_ClientKey(salted_password, stored_key) < 0 ||
-       scram_H(stored_key, SCRAM_KEY_LEN, stored_key) < 0 ||
-       scram_ServerKey(salted_password, server_key) < 0)
+                            salted_password, errstr) < 0 ||
+       scram_ClientKey(salted_password, stored_key, errstr) < 0 ||
+       scram_H(stored_key, SCRAM_KEY_LEN, stored_key, errstr) < 0 ||
+       scram_ServerKey(salted_password, server_key, errstr) < 0)
    {
+       /* errstr is filled already here */
 #ifdef FRONTEND
        return NULL;
 #else
-       elog(ERROR, "could not calculate stored key and server key");
+       elog(ERROR, "could not calculate stored key and server key: %s",
+            *errstr);
 #endif
    }
 
@@ -215,7 +243,10 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 #ifdef FRONTEND
    result = malloc(maxlen);
    if (!result)
+   {
+       *errstr = _("out of memory");
        return NULL;
+   }
 #else
    result = palloc(maxlen);
 #endif
@@ -226,11 +257,12 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
    encoded_result = pg_b64_encode(salt, saltlen, p, encoded_salt_len);
    if (encoded_result < 0)
    {
+       *errstr = _("could not encode salt");
 #ifdef FRONTEND
        free(result);
        return NULL;
 #else
-       elog(ERROR, "could not encode salt");
+       elog(ERROR, "%s", *errstr);
 #endif
    }
    p += encoded_result;
@@ -241,11 +273,12 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
                                   encoded_stored_len);
    if (encoded_result < 0)
    {
+       *errstr = _("could not encode stored key");
 #ifdef FRONTEND
        free(result);
        return NULL;
 #else
-       elog(ERROR, "could not encode stored key");
+       elog(ERROR, "%s", *errstr);
 #endif
    }
 
@@ -257,11 +290,12 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
                                   encoded_server_len);
    if (encoded_result < 0)
    {
+       *errstr = _("could not encode server key");
 #ifdef FRONTEND
        free(result);
        return NULL;
 #else
-       elog(ERROR, "could not encode server key");
+       elog(ERROR, "%s", *errstr);
 #endif
    }
 
index cf7aa17be4fa93f3bb962c6f29ff70cc0592eb46..c18783fe11fbf9b8c7bf2bbafb33e21eb883845d 100644 (file)
@@ -25,5 +25,6 @@ extern int    pg_hmac_init(pg_hmac_ctx *ctx, const uint8 *key, size_t len);
 extern int pg_hmac_update(pg_hmac_ctx *ctx, const uint8 *data, size_t len);
 extern int pg_hmac_final(pg_hmac_ctx *ctx, uint8 *dest, size_t len);
 extern void pg_hmac_free(pg_hmac_ctx *ctx);
+extern const char *pg_hmac_error(pg_hmac_ctx *ctx);
 
 #endif                         /* PG_HMAC_H */
index d53b4fa7f5b7adb84b2194658faf73dae089440a..d1f840c11c782992380bde08d22a7ed4449c2903 100644 (file)
 #define SCRAM_DEFAULT_ITERATIONS   4096
 
 extern int scram_SaltedPassword(const char *password, const char *salt,
-                                int saltlen, int iterations, uint8 *result);
-extern int scram_H(const uint8 *str, int len, uint8 *result);
-extern int scram_ClientKey(const uint8 *salted_password, uint8 *result);
-extern int scram_ServerKey(const uint8 *salted_password, uint8 *result);
+                                int saltlen, int iterations, uint8 *result,
+                                const char **errstr);
+extern int scram_H(const uint8 *str, int len, uint8 *result,
+                   const char **errstr);
+extern int scram_ClientKey(const uint8 *salted_password, uint8 *result,
+                           const char **errstr);
+extern int scram_ServerKey(const uint8 *salted_password, uint8 *result,
+                           const char **errstr);
 
 extern char *scram_build_secret(const char *salt, int saltlen, int iterations,
-                               const char *password);
+                               const char *password, const char **errstr);
 
 #endif                         /* SCRAM_COMMON_H */
index cc41440c4e63795c96ac892ec4b7232453961e04..e616200704162e948f6b0e303456a7540fd99f40 100644 (file)
@@ -80,10 +80,11 @@ static bool read_server_first_message(fe_scram_state *state, char *input);
 static bool read_server_final_message(fe_scram_state *state, char *input);
 static char *build_client_first_message(fe_scram_state *state);
 static char *build_client_final_message(fe_scram_state *state);
-static bool verify_server_signature(fe_scram_state *state, bool *match);
+static bool verify_server_signature(fe_scram_state *state, bool *match,
+                                   const char **errstr);
 static bool calculate_client_proof(fe_scram_state *state,
                                   const char *client_final_message_without_proof,
-                                  uint8 *result);
+                                  uint8 *result, const char **errstr);
 
 /*
  * Initialize SCRAM exchange status.
@@ -211,6 +212,7 @@ scram_exchange(void *opaq, char *input, int inputlen,
 {
    fe_scram_state *state = (fe_scram_state *) opaq;
    PGconn     *conn = state->conn;
+   const char *errstr = NULL;
 
    *done = false;
    *success = false;
@@ -273,10 +275,10 @@ scram_exchange(void *opaq, char *input, int inputlen,
             * Verify server signature, to make sure we're talking to the
             * genuine server.
             */
-           if (!verify_server_signature(state, success))
+           if (!verify_server_signature(state, success, &errstr))
            {
-               appendPQExpBufferStr(&conn->errorMessage,
-                                    libpq_gettext("could not verify server signature\n"));
+               appendPQExpBuffer(&conn->errorMessage,
+                                 libpq_gettext("could not verify server signature: %s\n"), errstr);
                goto error;
            }
 
@@ -469,6 +471,7 @@ build_client_final_message(fe_scram_state *state)
    uint8       client_proof[SCRAM_KEY_LEN];
    char       *result;
    int         encoded_len;
+   const char *errstr = NULL;
 
    initPQExpBuffer(&buf);
 
@@ -572,11 +575,12 @@ build_client_final_message(fe_scram_state *state)
    /* Append proof to it, to form client-final-message. */
    if (!calculate_client_proof(state,
                                state->client_final_message_without_proof,
-                               client_proof))
+                               client_proof, &errstr))
    {
        termPQExpBuffer(&buf);
-       appendPQExpBufferStr(&conn->errorMessage,
-                            libpq_gettext("could not calculate client proof\n"));
+       appendPQExpBuffer(&conn->errorMessage,
+                         libpq_gettext("could not calculate client proof: %s\n"),
+                         errstr);
        return NULL;
    }
 
@@ -782,12 +786,13 @@ read_server_final_message(fe_scram_state *state, char *input)
 
 /*
  * Calculate the client proof, part of the final exchange message sent
- * by the client.  Returns true on success, false on failure.
+ * by the client.  Returns true on success, false on failure with *errstr
+ * pointing to a message about the error details.
  */
 static bool
 calculate_client_proof(fe_scram_state *state,
                       const char *client_final_message_without_proof,
-                      uint8 *result)
+                      uint8 *result, const char **errstr)
 {
    uint8       StoredKey[SCRAM_KEY_LEN];
    uint8       ClientKey[SCRAM_KEY_LEN];
@@ -797,17 +802,27 @@ calculate_client_proof(fe_scram_state *state,
 
    ctx = pg_hmac_create(PG_SHA256);
    if (ctx == NULL)
+   {
+       *errstr = pg_hmac_error(NULL);  /* returns OOM */
        return false;
+   }
 
    /*
     * Calculate SaltedPassword, and store it in 'state' so that we can reuse
     * it later in verify_server_signature.
     */
    if (scram_SaltedPassword(state->password, state->salt, state->saltlen,
-                            state->iterations, state->SaltedPassword) < 0 ||
-       scram_ClientKey(state->SaltedPassword, ClientKey) < 0 ||
-       scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey) < 0 ||
-       pg_hmac_init(ctx, StoredKey, SCRAM_KEY_LEN) < 0 ||
+                            state->iterations, state->SaltedPassword,
+                            errstr) < 0 ||
+       scram_ClientKey(state->SaltedPassword, ClientKey, errstr) < 0 ||
+       scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey, errstr) < 0)
+   {
+       /* errstr is already filled here */
+       pg_hmac_free(ctx);
+       return false;
+   }
+
+   if (pg_hmac_init(ctx, StoredKey, SCRAM_KEY_LEN) < 0 ||
        pg_hmac_update(ctx,
                       (uint8 *) state->client_first_message_bare,
                       strlen(state->client_first_message_bare)) < 0 ||
@@ -821,6 +836,7 @@ calculate_client_proof(fe_scram_state *state,
                       strlen(client_final_message_without_proof)) < 0 ||
        pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0)
    {
+       *errstr = pg_hmac_error(ctx);
        pg_hmac_free(ctx);
        return false;
    }
@@ -836,10 +852,12 @@ calculate_client_proof(fe_scram_state *state,
  * Validate the server signature, received as part of the final exchange
  * message received from the server.  *match tracks if the server signature
  * matched or not. Returns true if the server signature got verified, and
- * false for a processing error.
+ * false for a processing error with *errstr pointing to a message about the
+ * error details.
  */
 static bool
-verify_server_signature(fe_scram_state *state, bool *match)
+verify_server_signature(fe_scram_state *state, bool *match,
+                       const char **errstr)
 {
    uint8       expected_ServerSignature[SCRAM_KEY_LEN];
    uint8       ServerKey[SCRAM_KEY_LEN];
@@ -847,11 +865,20 @@ verify_server_signature(fe_scram_state *state, bool *match)
 
    ctx = pg_hmac_create(PG_SHA256);
    if (ctx == NULL)
+   {
+       *errstr = pg_hmac_error(NULL);  /* returns OOM */
+       return false;
+   }
+
+   if (scram_ServerKey(state->SaltedPassword, ServerKey, errstr) < 0)
+   {
+       /* errstr is filled already */
+       pg_hmac_free(ctx);
        return false;
+   }
 
-   if (scram_ServerKey(state->SaltedPassword, ServerKey) < 0 ||
    /* calculate ServerSignature */
-       pg_hmac_init(ctx, ServerKey, SCRAM_KEY_LEN) < 0 ||
+   if (pg_hmac_init(ctx, ServerKey, SCRAM_KEY_LEN) < 0 ||
        pg_hmac_update(ctx,
                       (uint8 *) state->client_first_message_bare,
                       strlen(state->client_first_message_bare)) < 0 ||
@@ -866,6 +893,7 @@ verify_server_signature(fe_scram_state *state, bool *match)
        pg_hmac_final(ctx, expected_ServerSignature,
                      sizeof(expected_ServerSignature)) < 0)
    {
+       *errstr = pg_hmac_error(ctx);
        pg_hmac_free(ctx);
        return false;
    }
@@ -883,9 +911,12 @@ verify_server_signature(fe_scram_state *state, bool *match)
 
 /*
  * Build a new SCRAM secret.
+ *
+ * On error, returns NULL and sets *errstr to point to a message about the
+ * error details.
  */
 char *
-pg_fe_scram_build_secret(const char *password)
+pg_fe_scram_build_secret(const char *password, const char **errstr)
 {
    char       *prep_password;
    pg_saslprep_rc rc;
@@ -899,20 +930,25 @@ pg_fe_scram_build_secret(const char *password)
     */
    rc = pg_saslprep(password, &prep_password);
    if (rc == SASLPREP_OOM)
+   {
+       *errstr = _("out of memory");
        return NULL;
+   }
    if (rc == SASLPREP_SUCCESS)
        password = (const char *) prep_password;
 
    /* Generate a random salt */
    if (!pg_strong_random(saltbuf, SCRAM_DEFAULT_SALT_LEN))
    {
+       *errstr = _("failed to generate random salt");
        if (prep_password)
            free(prep_password);
        return NULL;
    }
 
    result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
-                               SCRAM_DEFAULT_ITERATIONS, password);
+                               SCRAM_DEFAULT_ITERATIONS, password,
+                               errstr);
 
    if (prep_password)
        free(prep_password);
index 2edc3f48e2e1331922a3d0fd506ef0ade0c7d5da..f8f4111fef7466cf1e60ff8c05206e71301ddfd1 100644 (file)
@@ -1293,11 +1293,13 @@ PQencryptPasswordConn(PGconn *conn, const char *passwd, const char *user,
     */
    if (strcmp(algorithm, "scram-sha-256") == 0)
    {
-       crypt_pwd = pg_fe_scram_build_secret(passwd);
-       /* We assume the only possible failure is OOM */
+       const char *errstr = NULL;
+
+       crypt_pwd = pg_fe_scram_build_secret(passwd, &errstr);
        if (!crypt_pwd)
-           appendPQExpBufferStr(&conn->errorMessage,
-                                libpq_gettext("out of memory\n"));
+           appendPQExpBuffer(&conn->errorMessage,
+                             libpq_gettext("could not encrypt password: %s\n"),
+                             errstr);
    }
    else if (strcmp(algorithm, "md5") == 0)
    {
index f22b3fe6488f71e3970b46436e2497ee377e2b9e..049a8bb1a101aa3bf3a71989a7afeb1b5e1d07e4 100644 (file)
@@ -25,6 +25,7 @@ extern char *pg_fe_getauthname(PQExpBuffer errorMessage);
 
 /* Mechanisms in fe-auth-scram.c */
 extern const pg_fe_sasl_mech pg_scram_mech;
-extern char *pg_fe_scram_build_secret(const char *password);
+extern char *pg_fe_scram_build_secret(const char *password,
+                                     const char **errstr);
 
 #endif                         /* FE_AUTH_H */
index 5015fa7db000a385f6f40b3564b504afc43dc2ac..89249ecc97ce7f363a443c5774f637a23b0c3a63 100644 (file)
@@ -3361,6 +3361,7 @@ pg_fe_sasl_mech
 pg_funcptr_t
 pg_gssinfo
 pg_hmac_ctx
+pg_hmac_errno
 pg_int64
 pg_local_to_utf_combined
 pg_locale_t