Skip to content

Commit

Permalink
Add NULL check for SHA3
Browse files Browse the repository at this point in the history
Signed-off-by: Songling Han <[email protected]>
  • Loading branch information
songlingatpan committed Sep 18, 2024
1 parent ac0ebf9 commit ad0e432
Show file tree
Hide file tree
Showing 9 changed files with 819 additions and 185 deletions.
54 changes: 48 additions & 6 deletions src/common/sha3/ossl_sha3.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
static void do_hash(uint8_t *output, const uint8_t *input, size_t inplen, const EVP_MD *md) {
EVP_MD_CTX *mdctx;
mdctx = OSSL_FUNC(EVP_MD_CTX_new)();
if (mdctx == NULL) {
return;
}
OSSL_FUNC(EVP_DigestInit_ex)(mdctx, md, NULL);
OSSL_FUNC(EVP_DigestUpdate)(mdctx, input, inplen);
OSSL_FUNC(EVP_DigestFinal_ex)(mdctx, output, NULL);
Expand All @@ -26,6 +29,9 @@ static void do_hash(uint8_t *output, const uint8_t *input, size_t inplen, const
static void do_xof(uint8_t *output, size_t outlen, const uint8_t *input, size_t inplen, const EVP_MD *md) {
EVP_MD_CTX *mdctx;
mdctx = OSSL_FUNC(EVP_MD_CTX_new)();
if (mdctx == NULL) {
return;
}
OSSL_FUNC(EVP_DigestInit_ex)(mdctx, md, NULL);
OSSL_FUNC(EVP_DigestUpdate)(mdctx, input, inplen);
OSSL_FUNC(EVP_DigestFinalXOF)(mdctx, output, outlen);
Expand All @@ -42,6 +48,9 @@ static void SHA3_sha3_256(uint8_t *output, const uint8_t *input, size_t inplen)

static void SHA3_sha3_256_inc_init(OQS_SHA3_sha3_256_inc_ctx *state) {
state->ctx = OSSL_FUNC(EVP_MD_CTX_new)();
if (state->ctx == NULL) {
return;
}
EVP_MD_CTX *s = (EVP_MD_CTX *)state->ctx;
OSSL_FUNC(EVP_DigestInit_ex)(s, oqs_sha3_256(), NULL);
}
Expand All @@ -55,7 +64,10 @@ static void SHA3_sha3_256_inc_finalize(uint8_t *output, OQS_SHA3_sha3_256_inc_ct
}

static void SHA3_sha3_256_inc_ctx_release(OQS_SHA3_sha3_256_inc_ctx *state) {
OSSL_FUNC(EVP_MD_CTX_free)((EVP_MD_CTX *)state->ctx);
if (state->ctx != NULL) {
OSSL_FUNC(EVP_MD_CTX_free)((EVP_MD_CTX *)state->ctx);
state->ctx = NULL;
}
}

static void SHA3_sha3_256_inc_ctx_clone(OQS_SHA3_sha3_256_inc_ctx *dest, const OQS_SHA3_sha3_256_inc_ctx *src) {
Expand All @@ -77,6 +89,9 @@ static void SHA3_sha3_384(uint8_t *output, const uint8_t *input, size_t inplen)
/* SHA3-384 incremental */
static void SHA3_sha3_384_inc_init(OQS_SHA3_sha3_384_inc_ctx *state) {
state->ctx = OSSL_FUNC(EVP_MD_CTX_new)();
if (state->ctx == NULL) {
return;
}
OSSL_FUNC(EVP_DigestInit_ex)((EVP_MD_CTX *)state->ctx, oqs_sha3_384(), NULL);
}

Expand All @@ -89,7 +104,10 @@ static void SHA3_sha3_384_inc_finalize(uint8_t *output, OQS_SHA3_sha3_384_inc_ct
}

static void SHA3_sha3_384_inc_ctx_release(OQS_SHA3_sha3_384_inc_ctx *state) {
OSSL_FUNC(EVP_MD_CTX_free)((EVP_MD_CTX *)state->ctx);
if (state->ctx != NULL) {
OSSL_FUNC(EVP_MD_CTX_free)((EVP_MD_CTX *)state->ctx);
state->ctx = NULL;
}
}

static void SHA3_sha3_384_inc_ctx_clone(OQS_SHA3_sha3_384_inc_ctx *dest, const OQS_SHA3_sha3_384_inc_ctx *src) {
Expand All @@ -112,6 +130,9 @@ static void SHA3_sha3_512(uint8_t *output, const uint8_t *input, size_t inplen)

static void SHA3_sha3_512_inc_init(OQS_SHA3_sha3_512_inc_ctx *state) {
state->ctx = OSSL_FUNC(EVP_MD_CTX_new)();
if (state->ctx == NULL) {
return;
}
OSSL_FUNC(EVP_DigestInit_ex)((EVP_MD_CTX *)state->ctx, oqs_sha3_512(), NULL);
}

Expand All @@ -124,7 +145,10 @@ static void SHA3_sha3_512_inc_finalize(uint8_t *output, OQS_SHA3_sha3_512_inc_ct
}

static void SHA3_sha3_512_inc_ctx_release(OQS_SHA3_sha3_512_inc_ctx *state) {
OSSL_FUNC(EVP_MD_CTX_free)((EVP_MD_CTX *)state->ctx);
if (state->ctx != NULL) {
OSSL_FUNC(EVP_MD_CTX_free)((EVP_MD_CTX *)state->ctx);
state->ctx = NULL;
}
}

static void SHA3_sha3_512_inc_ctx_clone(OQS_SHA3_sha3_512_inc_ctx *dest, const OQS_SHA3_sha3_512_inc_ctx *src) {
Expand Down Expand Up @@ -169,9 +193,17 @@ typedef struct {

static void SHA3_shake128_inc_init(OQS_SHA3_shake128_inc_ctx *state) {
state->ctx = OQS_MEM_malloc(sizeof(intrn_shake128_inc_ctx));
if (state->ctx == NULL) {
return;
}

intrn_shake128_inc_ctx *s = (intrn_shake128_inc_ctx *)state->ctx;
s->mdctx = OSSL_FUNC(EVP_MD_CTX_new)();
if (s->mdctx == NULL) {
OQS_MEM_free(state->ctx);
state->ctx = NULL;
return;
}
s->n_out = 0;
OSSL_FUNC(EVP_DigestInit_ex)(s->mdctx, oqs_shake128(), NULL);
}
Expand All @@ -193,15 +225,22 @@ static void SHA3_shake128_inc_squeeze(uint8_t *output, size_t outlen, OQS_SHA3_s
EVP_MD_CTX *clone;

clone = OSSL_FUNC(EVP_MD_CTX_new)();
if (clone == NULL) {
return;
}
OSSL_FUNC(EVP_DigestInit_ex)(clone, oqs_shake128(), NULL);
OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx);
if (s->n_out == 0) {
OSSL_FUNC(EVP_DigestFinalXOF)(clone, output, outlen);
} else {
uint8_t *tmp = OQS_MEM_checked_malloc(s->n_out + outlen);
if (tmp == NULL) {
OSSL_FUNC(EVP_MD_CTX_free)(clone);
return;
}
OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen);
memcpy(output, tmp + s->n_out, outlen);
OQS_MEM_free(tmp); // IGNORE free-check
OQS_MEM_secure_free(tmp, s->n_out + outlen);
}
OSSL_FUNC(EVP_MD_CTX_free)(clone);
s->n_out += outlen;
Expand All @@ -210,8 +249,11 @@ static void SHA3_shake128_inc_squeeze(uint8_t *output, size_t outlen, OQS_SHA3_s

static void SHA3_shake128_inc_ctx_release(OQS_SHA3_shake128_inc_ctx *state) {
intrn_shake128_inc_ctx *s = (intrn_shake128_inc_ctx *)state->ctx;
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx);
OQS_MEM_free(s); // IGNORE free-check
if (s != NULL) {
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx);
OQS_MEM_secure_free(s, sizeof(intrn_shake128_inc_ctx));
state->ctx = NULL;
}
}

static void SHA3_shake128_inc_ctx_clone(OQS_SHA3_shake128_inc_ctx *dest, const OQS_SHA3_shake128_inc_ctx *src) {
Expand Down
102 changes: 97 additions & 5 deletions src/common/sha3/ossl_sha3x4.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,28 @@ typedef struct {
} intrn_shake128_x4_inc_ctx;

static void SHA3_shake128_x4_inc_init(OQS_SHA3_shake128_x4_inc_ctx *state) {
if (state == NULL) {
return;
}
state->ctx = OQS_MEM_malloc(sizeof(intrn_shake128_x4_inc_ctx));
if (state->ctx == NULL) {
return;
}

intrn_shake128_x4_inc_ctx *s = (intrn_shake128_x4_inc_ctx *)state->ctx;
s->mdctx0 = OSSL_FUNC(EVP_MD_CTX_new)();
s->mdctx1 = OSSL_FUNC(EVP_MD_CTX_new)();
s->mdctx2 = OSSL_FUNC(EVP_MD_CTX_new)();
s->mdctx3 = OSSL_FUNC(EVP_MD_CTX_new)();
if (s->mdctx0 == NULL || s->mdctx1 == NULL || s->mdctx2 == NULL || s->mdctx3 == NULL) {
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx0);
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx1);
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx2);
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx3);
OQS_MEM_free(s);
state->ctx = NULL;
return;
}
OSSL_FUNC(EVP_DigestInit_ex)(s->mdctx0, oqs_shake128(), NULL);
OSSL_FUNC(EVP_DigestInit_ex)(s->mdctx1, oqs_shake128(), NULL);
OSSL_FUNC(EVP_DigestInit_ex)(s->mdctx2, oqs_shake128(), NULL);
Expand All @@ -48,6 +63,9 @@ static void SHA3_shake128_x4_inc_init(OQS_SHA3_shake128_x4_inc_ctx *state) {
}

static void SHA3_shake128_x4_inc_absorb(OQS_SHA3_shake128_x4_inc_ctx *state, const uint8_t *in0, const uint8_t *in1, const uint8_t *in2, const uint8_t *in3, size_t inplen) {
if (state == NULL || state->ctx == NULL) {
return;
}
intrn_shake128_x4_inc_ctx *s = (intrn_shake128_x4_inc_ctx *)state->ctx;
OSSL_FUNC(EVP_DigestUpdate)(s->mdctx0, in0, inplen);
OSSL_FUNC(EVP_DigestUpdate)(s->mdctx1, in1, inplen);
Expand All @@ -60,6 +78,9 @@ static void SHA3_shake128_x4_inc_finalize(OQS_SHA3_shake128_x4_inc_ctx *state) {
}

static void SHA3_shake128_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, size_t outlen, OQS_SHA3_shake128_x4_inc_ctx *state) {
if (state == NULL || state->ctx == NULL) {
return;
}
intrn_shake128_x4_inc_ctx *s = (intrn_shake128_x4_inc_ctx *)state->ctx;
#if OPENSSL_VERSION_NUMBER >= 0x30300000L
EVP_DigestSqueeze(s->mdctx0, out0, outlen);
Expand All @@ -70,6 +91,9 @@ static void SHA3_shake128_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *
EVP_MD_CTX *clone;

clone = OSSL_FUNC(EVP_MD_CTX_new)();
if (clone == NULL) {
return;
}
OSSL_FUNC(EVP_DigestInit_ex)(clone, oqs_shake128(), NULL);
if (s->n_out == 0) {
OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx0);
Expand All @@ -82,6 +106,10 @@ static void SHA3_shake128_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *
OSSL_FUNC(EVP_DigestFinalXOF)(clone, out3, outlen);
} else {
uint8_t *tmp = OQS_MEM_checked_malloc(s->n_out + outlen);
if (tmp == NULL) {
OSSL_FUNC(EVP_MD_CTX_free)(clone);
return;
}
OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx0);
OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen);
memcpy(out0, tmp + s->n_out, outlen);
Expand All @@ -94,16 +122,36 @@ static void SHA3_shake128_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *
OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx3);
OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen);
memcpy(out3, tmp + s->n_out, outlen);
OQS_MEM_free(tmp); // IGNORE free-check
OQS_MEM_free(tmp);
}
OSSL_FUNC(EVP_MD_CTX_free)(clone);
s->n_out += outlen;
#endif
}

static void SHA3_shake128_x4_inc_ctx_clone(OQS_SHA3_shake128_x4_inc_ctx *dest, const OQS_SHA3_shake128_x4_inc_ctx *src) {
if (dest == NULL || src == NULL || src->ctx == NULL) {
return;
}
dest->ctx = OQS_MEM_malloc(sizeof(intrn_shake128_x4_inc_ctx));
if (dest->ctx == NULL) {
return;
}
intrn_shake128_x4_inc_ctx *s = (intrn_shake128_x4_inc_ctx *)src->ctx;
intrn_shake128_x4_inc_ctx *d = (intrn_shake128_x4_inc_ctx *)dest->ctx;
d->mdctx0 = OSSL_FUNC(EVP_MD_CTX_new)();
d->mdctx1 = OSSL_FUNC(EVP_MD_CTX_new)();
d->mdctx2 = OSSL_FUNC(EVP_MD_CTX_new)();
d->mdctx3 = OSSL_FUNC(EVP_MD_CTX_new)();
if (d->mdctx0 == NULL || d->mdctx1 == NULL || d->mdctx2 == NULL || d->mdctx3 == NULL) {
OSSL_FUNC(EVP_MD_CTX_free)(d->mdctx0);
OSSL_FUNC(EVP_MD_CTX_free)(d->mdctx1);
OSSL_FUNC(EVP_MD_CTX_free)(d->mdctx2);
OSSL_FUNC(EVP_MD_CTX_free)(d->mdctx3);
OQS_MEM_free(d);
dest->ctx = NULL;
return;
}
OSSL_FUNC(EVP_MD_CTX_copy_ex)(d->mdctx0, s->mdctx0);
OSSL_FUNC(EVP_MD_CTX_copy_ex)(d->mdctx1, s->mdctx1);
OSSL_FUNC(EVP_MD_CTX_copy_ex)(d->mdctx2, s->mdctx2);
Expand All @@ -112,15 +160,22 @@ static void SHA3_shake128_x4_inc_ctx_clone(OQS_SHA3_shake128_x4_inc_ctx *dest, c
}

static void SHA3_shake128_x4_inc_ctx_release(OQS_SHA3_shake128_x4_inc_ctx *state) {
if (state == NULL || state->ctx == NULL) {
return;
}
intrn_shake128_x4_inc_ctx *s = (intrn_shake128_x4_inc_ctx *)state->ctx;
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx0);
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx1);
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx2);
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx3);
OQS_MEM_free(s); // IGNORE free-check
OQS_MEM_free(s);
state->ctx = NULL;
}

static void SHA3_shake128_x4_inc_ctx_reset(OQS_SHA3_shake128_x4_inc_ctx *state) {
if (state == NULL || state->ctx == NULL) {
return;
}
intrn_shake128_x4_inc_ctx *s = (intrn_shake128_x4_inc_ctx *)state->ctx;
OSSL_FUNC(EVP_MD_CTX_reset)(s->mdctx0);
OSSL_FUNC(EVP_MD_CTX_reset)(s->mdctx1);
Expand Down Expand Up @@ -152,15 +207,29 @@ typedef struct {
EVP_MD_CTX *mdctx3;
size_t n_out;
} intrn_shake256_x4_inc_ctx;

static void SHA3_shake256_x4_inc_init(OQS_SHA3_shake256_x4_inc_ctx *state) {
if (state == NULL) {
return;
}
state->ctx = OQS_MEM_malloc(sizeof(intrn_shake256_x4_inc_ctx));
if (state->ctx == NULL) {
return;
}

intrn_shake256_x4_inc_ctx *s = (intrn_shake256_x4_inc_ctx *)state->ctx;
s->mdctx0 = OSSL_FUNC(EVP_MD_CTX_new)();
s->mdctx1 = OSSL_FUNC(EVP_MD_CTX_new)();
s->mdctx2 = OSSL_FUNC(EVP_MD_CTX_new)();
s->mdctx3 = OSSL_FUNC(EVP_MD_CTX_new)();
if (s->mdctx0 == NULL || s->mdctx1 == NULL || s->mdctx2 == NULL || s->mdctx3 == NULL) {
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx0);
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx1);
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx2);
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx3);
OQS_MEM_free(s);
state->ctx = NULL;
return;
}
OSSL_FUNC(EVP_DigestInit_ex)(s->mdctx0, oqs_shake256(), NULL);
OSSL_FUNC(EVP_DigestInit_ex)(s->mdctx1, oqs_shake256(), NULL);
OSSL_FUNC(EVP_DigestInit_ex)(s->mdctx2, oqs_shake256(), NULL);
Expand All @@ -169,6 +238,9 @@ static void SHA3_shake256_x4_inc_init(OQS_SHA3_shake256_x4_inc_ctx *state) {
}

static void SHA3_shake256_x4_inc_absorb(OQS_SHA3_shake256_x4_inc_ctx *state, const uint8_t *in0, const uint8_t *in1, const uint8_t *in2, const uint8_t *in3, size_t inplen) {
if (state == NULL || state->ctx == NULL) {
return;
}
intrn_shake256_x4_inc_ctx *s = (intrn_shake256_x4_inc_ctx *)state->ctx;
OSSL_FUNC(EVP_DigestUpdate)(s->mdctx0, in0, inplen);
OSSL_FUNC(EVP_DigestUpdate)(s->mdctx1, in1, inplen);
Expand All @@ -181,6 +253,9 @@ static void SHA3_shake256_x4_inc_finalize(OQS_SHA3_shake256_x4_inc_ctx *state) {
}

static void SHA3_shake256_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, size_t outlen, OQS_SHA3_shake256_x4_inc_ctx *state) {
if (state == NULL || state->ctx == NULL) {
return;
}
intrn_shake256_x4_inc_ctx *s = (intrn_shake256_x4_inc_ctx *)state->ctx;
#if OPENSSL_VERSION_NUMBER >= 0x30300000L
EVP_DigestSqueeze(s->mdctx0, out0, outlen);
Expand All @@ -191,6 +266,9 @@ static void SHA3_shake256_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *
EVP_MD_CTX *clone;

clone = OSSL_FUNC(EVP_MD_CTX_new)();
if (clone == NULL) {
return;
}
OSSL_FUNC(EVP_DigestInit_ex)(clone, oqs_shake256(), NULL);
if (s->n_out == 0) {
OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx0);
Expand All @@ -203,6 +281,10 @@ static void SHA3_shake256_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *
OSSL_FUNC(EVP_DigestFinalXOF)(clone, out3, outlen);
} else {
uint8_t *tmp = OQS_MEM_checked_malloc(s->n_out + outlen);
if (tmp == NULL) {
OSSL_FUNC(EVP_MD_CTX_free)(clone);
return;
}
OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx0);
OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen);
memcpy(out0, tmp + s->n_out, outlen);
Expand All @@ -215,14 +297,17 @@ static void SHA3_shake256_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *
OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx3);
OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen);
memcpy(out3, tmp + s->n_out, outlen);
OQS_MEM_free(tmp); // IGNORE free-check
OQS_MEM_free(tmp);
}
OSSL_FUNC(EVP_MD_CTX_free)(clone);
s->n_out += outlen;
#endif
}

static void SHA3_shake256_x4_inc_ctx_clone(OQS_SHA3_shake256_x4_inc_ctx *dest, const OQS_SHA3_shake256_x4_inc_ctx *src) {
if (dest == NULL || src == NULL || dest->ctx == NULL || src->ctx == NULL) {
return;
}
intrn_shake256_x4_inc_ctx *s = (intrn_shake256_x4_inc_ctx *)src->ctx;
intrn_shake256_x4_inc_ctx *d = (intrn_shake256_x4_inc_ctx *)dest->ctx;
OSSL_FUNC(EVP_MD_CTX_copy_ex)(d->mdctx0, s->mdctx0);
Expand All @@ -233,15 +318,22 @@ static void SHA3_shake256_x4_inc_ctx_clone(OQS_SHA3_shake256_x4_inc_ctx *dest, c
}

static void SHA3_shake256_x4_inc_ctx_release(OQS_SHA3_shake256_x4_inc_ctx *state) {
if (state == NULL || state->ctx == NULL) {
return;
}
intrn_shake256_x4_inc_ctx *s = (intrn_shake256_x4_inc_ctx *)state->ctx;
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx0);
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx1);
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx2);
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx3);
OQS_MEM_free(s); // IGNORE free-check
OQS_MEM_free(s);
state->ctx = NULL;
}

static void SHA3_shake256_x4_inc_ctx_reset(OQS_SHA3_shake256_x4_inc_ctx *state) {
if (state == NULL || state->ctx == NULL) {
return;
}
intrn_shake256_x4_inc_ctx *s = (intrn_shake256_x4_inc_ctx *)state->ctx;
OSSL_FUNC(EVP_MD_CTX_reset)(s->mdctx0);
OSSL_FUNC(EVP_MD_CTX_reset)(s->mdctx1);
Expand Down
Loading

0 comments on commit ad0e432

Please sign in to comment.