summaryrefslogtreecommitdiff
path: root/src/common
diff options
context:
space:
mode:
Diffstat (limited to 'src/common')
-rw-r--r--src/common/base64.c59
-rw-r--r--src/common/scram-common.c59
2 files changed, 104 insertions, 14 deletions
diff --git a/src/common/base64.c b/src/common/base64.c
index 55c8983f97d..57ec06c3a95 100644
--- a/src/common/base64.c
+++ b/src/common/base64.c
@@ -42,10 +42,11 @@ static const int8 b64lookup[128] = {
* pg_b64_encode
*
* Encode into base64 the given string. Returns the length of the encoded
- * string.
+ * string, and -1 in the event of an error with the result buffer zeroed
+ * for safety.
*/
int
-pg_b64_encode(const char *src, int len, char *dst)
+pg_b64_encode(const char *src, int len, char *dst, int dstlen)
{
char *p;
const char *s,
@@ -65,6 +66,13 @@ pg_b64_encode(const char *src, int len, char *dst)
/* write it out */
if (pos < 0)
{
+ /*
+ * Leave if there is an overflow in the area allocated for the
+ * encoded string.
+ */
+ if ((p - dst + 4) > dstlen)
+ goto error;
+
*p++ = _base64[(buf >> 18) & 0x3f];
*p++ = _base64[(buf >> 12) & 0x3f];
*p++ = _base64[(buf >> 6) & 0x3f];
@@ -76,23 +84,36 @@ pg_b64_encode(const char *src, int len, char *dst)
}
if (pos != 2)
{
+ /*
+ * Leave if there is an overflow in the area allocated for the encoded
+ * string.
+ */
+ if ((p - dst + 4) > dstlen)
+ goto error;
+
*p++ = _base64[(buf >> 18) & 0x3f];
*p++ = _base64[(buf >> 12) & 0x3f];
*p++ = (pos == 0) ? _base64[(buf >> 6) & 0x3f] : '=';
*p++ = '=';
}
+ Assert((p - dst) <= dstlen);
return p - dst;
+
+error:
+ memset(dst, 0, dstlen);
+ return -1;
}
/*
* pg_b64_decode
*
* Decode the given base64 string. Returns the length of the decoded
- * string on success, and -1 in the event of an error.
+ * string on success, and -1 in the event of an error with the result
+ * buffer zeroed for safety.
*/
int
-pg_b64_decode(const char *src, int len, char *dst)
+pg_b64_decode(const char *src, int len, char *dst, int dstlen)
{
const char *srcend = src + len,
*s = src;
@@ -109,7 +130,7 @@ pg_b64_decode(const char *src, int len, char *dst)
/* Leave if a whitespace is found */
if (c == ' ' || c == '\t' || c == '\n' || c == '\r')
- return -1;
+ goto error;
if (c == '=')
{
@@ -126,7 +147,7 @@ pg_b64_decode(const char *src, int len, char *dst)
* Unexpected "=" character found while decoding base64
* sequence.
*/
- return -1;
+ goto error;
}
}
b = 0;
@@ -139,7 +160,7 @@ pg_b64_decode(const char *src, int len, char *dst)
if (b < 0)
{
/* invalid symbol found */
- return -1;
+ goto error;
}
}
/* add it to buffer */
@@ -147,11 +168,28 @@ pg_b64_decode(const char *src, int len, char *dst)
pos++;
if (pos == 4)
{
+ /*
+ * Leave if there is an overflow in the area allocated for the
+ * decoded string.
+ */
+ if ((p - dst + 1) > dstlen)
+ goto error;
*p++ = (buf >> 16) & 255;
+
if (end == 0 || end > 1)
+ {
+ /* overflow check */
+ if ((p - dst + 1) > dstlen)
+ goto error;
*p++ = (buf >> 8) & 255;
+ }
if (end == 0 || end > 2)
+ {
+ /* overflow check */
+ if ((p - dst + 1) > dstlen)
+ goto error;
*p++ = buf & 255;
+ }
buf = 0;
pos = 0;
}
@@ -163,10 +201,15 @@ pg_b64_decode(const char *src, int len, char *dst)
* base64 end sequence is invalid. Input data is missing padding, is
* truncated or is otherwise corrupted.
*/
- return -1;
+ goto error;
}
+ Assert((p - dst) <= dstlen);
return p - dst;
+
+error:
+ memset(dst, 0, dstlen);
+ return -1;
}
/*
diff --git a/src/common/scram-common.c b/src/common/scram-common.c
index c30dfc97dca..dff9723e67f 100644
--- a/src/common/scram-common.c
+++ b/src/common/scram-common.c
@@ -198,6 +198,10 @@ scram_build_verifier(const char *salt, int saltlen, int iterations,
char *result;
char *p;
int maxlen;
+ int encoded_salt_len;
+ int encoded_stored_len;
+ int encoded_server_len;
+ int encoded_result;
if (iterations <= 0)
iterations = SCRAM_DEFAULT_ITERATIONS;
@@ -215,11 +219,15 @@ scram_build_verifier(const char *salt, int saltlen, int iterations,
* SCRAM-SHA-256$<iteration count>:<salt>$<StoredKey>:<ServerKey>
*----------
*/
+ encoded_salt_len = pg_b64_enc_len(saltlen);
+ encoded_stored_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+ encoded_server_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+
maxlen = strlen("SCRAM-SHA-256") + 1
+ 10 + 1 /* iteration count */
- + pg_b64_enc_len(saltlen) + 1 /* Base64-encoded salt */
- + pg_b64_enc_len(SCRAM_KEY_LEN) + 1 /* Base64-encoded StoredKey */
- + pg_b64_enc_len(SCRAM_KEY_LEN) + 1; /* Base64-encoded ServerKey */
+ + encoded_salt_len + 1 /* Base64-encoded salt */
+ + encoded_stored_len + 1 /* Base64-encoded StoredKey */
+ + encoded_server_len + 1; /* Base64-encoded ServerKey */
#ifdef FRONTEND
result = malloc(maxlen);
@@ -231,11 +239,50 @@ scram_build_verifier(const char *salt, int saltlen, int iterations,
p = result + sprintf(result, "SCRAM-SHA-256$%d:", iterations);
- p += pg_b64_encode(salt, saltlen, p);
+ /* salt */
+ encoded_result = pg_b64_encode(salt, saltlen, p, encoded_salt_len);
+ if (encoded_result < 0)
+ {
+#ifdef FRONTEND
+ free(result);
+ return NULL;
+#else
+ elog(ERROR, "could not encode salt");
+#endif
+ }
+ p += encoded_result;
*(p++) = '$';
- p += pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p);
+
+ /* stored key */
+ encoded_result = pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p,
+ encoded_stored_len);
+ if (encoded_result < 0)
+ {
+#ifdef FRONTEND
+ free(result);
+ return NULL;
+#else
+ elog(ERROR, "could not encode stored key");
+#endif
+ }
+
+ p += encoded_result;
*(p++) = ':';
- p += pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p);
+
+ /* server key */
+ encoded_result = pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p,
+ encoded_server_len);
+ if (encoded_result < 0)
+ {
+#ifdef FRONTEND
+ free(result);
+ return NULL;
+#else
+ elog(ERROR, "could not encode server key");
+#endif
+ }
+
+ p += encoded_result;
*(p++) = '\0';
Assert(p - result <= maxlen);