diff options
Diffstat (limited to 'src/common')
-rw-r--r-- | src/common/base64.c | 59 | ||||
-rw-r--r-- | src/common/scram-common.c | 59 |
2 files changed, 104 insertions, 14 deletions
diff --git a/src/common/base64.c b/src/common/base64.c index 55c8983f97d..57ec06c3a95 100644 --- a/src/common/base64.c +++ b/src/common/base64.c @@ -42,10 +42,11 @@ static const int8 b64lookup[128] = { * pg_b64_encode * * Encode into base64 the given string. Returns the length of the encoded - * string. + * string, and -1 in the event of an error with the result buffer zeroed + * for safety. */ int -pg_b64_encode(const char *src, int len, char *dst) +pg_b64_encode(const char *src, int len, char *dst, int dstlen) { char *p; const char *s, @@ -65,6 +66,13 @@ pg_b64_encode(const char *src, int len, char *dst) /* write it out */ if (pos < 0) { + /* + * Leave if there is an overflow in the area allocated for the + * encoded string. + */ + if ((p - dst + 4) > dstlen) + goto error; + *p++ = _base64[(buf >> 18) & 0x3f]; *p++ = _base64[(buf >> 12) & 0x3f]; *p++ = _base64[(buf >> 6) & 0x3f]; @@ -76,23 +84,36 @@ pg_b64_encode(const char *src, int len, char *dst) } if (pos != 2) { + /* + * Leave if there is an overflow in the area allocated for the encoded + * string. + */ + if ((p - dst + 4) > dstlen) + goto error; + *p++ = _base64[(buf >> 18) & 0x3f]; *p++ = _base64[(buf >> 12) & 0x3f]; *p++ = (pos == 0) ? _base64[(buf >> 6) & 0x3f] : '='; *p++ = '='; } + Assert((p - dst) <= dstlen); return p - dst; + +error: + memset(dst, 0, dstlen); + return -1; } /* * pg_b64_decode * * Decode the given base64 string. Returns the length of the decoded - * string on success, and -1 in the event of an error. + * string on success, and -1 in the event of an error with the result + * buffer zeroed for safety. */ int -pg_b64_decode(const char *src, int len, char *dst) +pg_b64_decode(const char *src, int len, char *dst, int dstlen) { const char *srcend = src + len, *s = src; @@ -109,7 +130,7 @@ pg_b64_decode(const char *src, int len, char *dst) /* Leave if a whitespace is found */ if (c == ' ' || c == '\t' || c == '\n' || c == '\r') - return -1; + goto error; if (c == '=') { @@ -126,7 +147,7 @@ pg_b64_decode(const char *src, int len, char *dst) * Unexpected "=" character found while decoding base64 * sequence. */ - return -1; + goto error; } } b = 0; @@ -139,7 +160,7 @@ pg_b64_decode(const char *src, int len, char *dst) if (b < 0) { /* invalid symbol found */ - return -1; + goto error; } } /* add it to buffer */ @@ -147,11 +168,28 @@ pg_b64_decode(const char *src, int len, char *dst) pos++; if (pos == 4) { + /* + * Leave if there is an overflow in the area allocated for the + * decoded string. + */ + if ((p - dst + 1) > dstlen) + goto error; *p++ = (buf >> 16) & 255; + if (end == 0 || end > 1) + { + /* overflow check */ + if ((p - dst + 1) > dstlen) + goto error; *p++ = (buf >> 8) & 255; + } if (end == 0 || end > 2) + { + /* overflow check */ + if ((p - dst + 1) > dstlen) + goto error; *p++ = buf & 255; + } buf = 0; pos = 0; } @@ -163,10 +201,15 @@ pg_b64_decode(const char *src, int len, char *dst) * base64 end sequence is invalid. Input data is missing padding, is * truncated or is otherwise corrupted. */ - return -1; + goto error; } + Assert((p - dst) <= dstlen); return p - dst; + +error: + memset(dst, 0, dstlen); + return -1; } /* diff --git a/src/common/scram-common.c b/src/common/scram-common.c index c30dfc97dca..dff9723e67f 100644 --- a/src/common/scram-common.c +++ b/src/common/scram-common.c @@ -198,6 +198,10 @@ scram_build_verifier(const char *salt, int saltlen, int iterations, char *result; char *p; int maxlen; + int encoded_salt_len; + int encoded_stored_len; + int encoded_server_len; + int encoded_result; if (iterations <= 0) iterations = SCRAM_DEFAULT_ITERATIONS; @@ -215,11 +219,15 @@ scram_build_verifier(const char *salt, int saltlen, int iterations, * SCRAM-SHA-256$<iteration count>:<salt>$<StoredKey>:<ServerKey> *---------- */ + encoded_salt_len = pg_b64_enc_len(saltlen); + encoded_stored_len = pg_b64_enc_len(SCRAM_KEY_LEN); + encoded_server_len = pg_b64_enc_len(SCRAM_KEY_LEN); + maxlen = strlen("SCRAM-SHA-256") + 1 + 10 + 1 /* iteration count */ - + pg_b64_enc_len(saltlen) + 1 /* Base64-encoded salt */ - + pg_b64_enc_len(SCRAM_KEY_LEN) + 1 /* Base64-encoded StoredKey */ - + pg_b64_enc_len(SCRAM_KEY_LEN) + 1; /* Base64-encoded ServerKey */ + + encoded_salt_len + 1 /* Base64-encoded salt */ + + encoded_stored_len + 1 /* Base64-encoded StoredKey */ + + encoded_server_len + 1; /* Base64-encoded ServerKey */ #ifdef FRONTEND result = malloc(maxlen); @@ -231,11 +239,50 @@ scram_build_verifier(const char *salt, int saltlen, int iterations, p = result + sprintf(result, "SCRAM-SHA-256$%d:", iterations); - p += pg_b64_encode(salt, saltlen, p); + /* salt */ + encoded_result = pg_b64_encode(salt, saltlen, p, encoded_salt_len); + if (encoded_result < 0) + { +#ifdef FRONTEND + free(result); + return NULL; +#else + elog(ERROR, "could not encode salt"); +#endif + } + p += encoded_result; *(p++) = '$'; - p += pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p); + + /* stored key */ + encoded_result = pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p, + encoded_stored_len); + if (encoded_result < 0) + { +#ifdef FRONTEND + free(result); + return NULL; +#else + elog(ERROR, "could not encode stored key"); +#endif + } + + p += encoded_result; *(p++) = ':'; - p += pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p); + + /* server key */ + encoded_result = pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p, + encoded_server_len); + if (encoded_result < 0) + { +#ifdef FRONTEND + free(result); + return NULL; +#else + elog(ERROR, "could not encode server key"); +#endif + } + + p += encoded_result; *(p++) = '\0'; Assert(p - result <= maxlen); |