diff --git a/src/common/sha3/ossl_sha3.c b/src/common/sha3/ossl_sha3.c index 2ac3e98cb5..c5c9bb53b1 100644 --- a/src/common/sha3/ossl_sha3.c +++ b/src/common/sha3/ossl_sha3.c @@ -17,6 +17,9 @@ static void do_hash(uint8_t *output, const uint8_t *input, size_t inplen, const EVP_MD *md) { EVP_MD_CTX *mdctx; mdctx = OSSL_FUNC(EVP_MD_CTX_new)(); + if (mdctx == NULL) { + return; + } OSSL_FUNC(EVP_DigestInit_ex)(mdctx, md, NULL); OSSL_FUNC(EVP_DigestUpdate)(mdctx, input, inplen); OSSL_FUNC(EVP_DigestFinal_ex)(mdctx, output, NULL); @@ -26,6 +29,9 @@ static void do_hash(uint8_t *output, const uint8_t *input, size_t inplen, const static void do_xof(uint8_t *output, size_t outlen, const uint8_t *input, size_t inplen, const EVP_MD *md) { EVP_MD_CTX *mdctx; mdctx = OSSL_FUNC(EVP_MD_CTX_new)(); + if (mdctx == NULL) { + return; + } OSSL_FUNC(EVP_DigestInit_ex)(mdctx, md, NULL); OSSL_FUNC(EVP_DigestUpdate)(mdctx, input, inplen); OSSL_FUNC(EVP_DigestFinalXOF)(mdctx, output, outlen); @@ -42,6 +48,9 @@ static void SHA3_sha3_256(uint8_t *output, const uint8_t *input, size_t inplen) static void SHA3_sha3_256_inc_init(OQS_SHA3_sha3_256_inc_ctx *state) { state->ctx = OSSL_FUNC(EVP_MD_CTX_new)(); + if (state->ctx == NULL) { + return; + } EVP_MD_CTX *s = (EVP_MD_CTX *)state->ctx; OSSL_FUNC(EVP_DigestInit_ex)(s, oqs_sha3_256(), NULL); } @@ -55,7 +64,10 @@ static void SHA3_sha3_256_inc_finalize(uint8_t *output, OQS_SHA3_sha3_256_inc_ct } static void SHA3_sha3_256_inc_ctx_release(OQS_SHA3_sha3_256_inc_ctx *state) { - OSSL_FUNC(EVP_MD_CTX_free)((EVP_MD_CTX *)state->ctx); + if (state->ctx != NULL) { + OSSL_FUNC(EVP_MD_CTX_free)((EVP_MD_CTX *)state->ctx); + state->ctx = NULL; + } } static void SHA3_sha3_256_inc_ctx_clone(OQS_SHA3_sha3_256_inc_ctx *dest, const OQS_SHA3_sha3_256_inc_ctx *src) { @@ -77,6 +89,9 @@ static void SHA3_sha3_384(uint8_t *output, const uint8_t *input, size_t inplen) /* SHA3-384 incremental */ static void SHA3_sha3_384_inc_init(OQS_SHA3_sha3_384_inc_ctx *state) { state->ctx = OSSL_FUNC(EVP_MD_CTX_new)(); + if (state->ctx == NULL) { + return; + } OSSL_FUNC(EVP_DigestInit_ex)((EVP_MD_CTX *)state->ctx, oqs_sha3_384(), NULL); } @@ -89,7 +104,10 @@ static void SHA3_sha3_384_inc_finalize(uint8_t *output, OQS_SHA3_sha3_384_inc_ct } static void SHA3_sha3_384_inc_ctx_release(OQS_SHA3_sha3_384_inc_ctx *state) { - OSSL_FUNC(EVP_MD_CTX_free)((EVP_MD_CTX *)state->ctx); + if (state->ctx != NULL) { + OSSL_FUNC(EVP_MD_CTX_free)((EVP_MD_CTX *)state->ctx); + state->ctx = NULL; + } } static void SHA3_sha3_384_inc_ctx_clone(OQS_SHA3_sha3_384_inc_ctx *dest, const OQS_SHA3_sha3_384_inc_ctx *src) { @@ -112,6 +130,9 @@ static void SHA3_sha3_512(uint8_t *output, const uint8_t *input, size_t inplen) static void SHA3_sha3_512_inc_init(OQS_SHA3_sha3_512_inc_ctx *state) { state->ctx = OSSL_FUNC(EVP_MD_CTX_new)(); + if (state->ctx == NULL) { + return; + } OSSL_FUNC(EVP_DigestInit_ex)((EVP_MD_CTX *)state->ctx, oqs_sha3_512(), NULL); } @@ -124,7 +145,10 @@ static void SHA3_sha3_512_inc_finalize(uint8_t *output, OQS_SHA3_sha3_512_inc_ct } static void SHA3_sha3_512_inc_ctx_release(OQS_SHA3_sha3_512_inc_ctx *state) { - OSSL_FUNC(EVP_MD_CTX_free)((EVP_MD_CTX *)state->ctx); + if (state->ctx != NULL) { + OSSL_FUNC(EVP_MD_CTX_free)((EVP_MD_CTX *)state->ctx); + state->ctx = NULL; + } } static void SHA3_sha3_512_inc_ctx_clone(OQS_SHA3_sha3_512_inc_ctx *dest, const OQS_SHA3_sha3_512_inc_ctx *src) { @@ -169,9 +193,17 @@ typedef struct { static void SHA3_shake128_inc_init(OQS_SHA3_shake128_inc_ctx *state) { state->ctx = OQS_MEM_malloc(sizeof(intrn_shake128_inc_ctx)); + if (state->ctx == NULL) { + return; + } intrn_shake128_inc_ctx *s = (intrn_shake128_inc_ctx *)state->ctx; s->mdctx = OSSL_FUNC(EVP_MD_CTX_new)(); + if (s->mdctx == NULL) { + OQS_MEM_free(state->ctx); + state->ctx = NULL; + return; + } s->n_out = 0; OSSL_FUNC(EVP_DigestInit_ex)(s->mdctx, oqs_shake128(), NULL); } @@ -193,15 +225,22 @@ static void SHA3_shake128_inc_squeeze(uint8_t *output, size_t outlen, OQS_SHA3_s EVP_MD_CTX *clone; clone = OSSL_FUNC(EVP_MD_CTX_new)(); + if (clone == NULL) { + return; + } OSSL_FUNC(EVP_DigestInit_ex)(clone, oqs_shake128(), NULL); OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx); if (s->n_out == 0) { OSSL_FUNC(EVP_DigestFinalXOF)(clone, output, outlen); } else { uint8_t *tmp = OQS_MEM_checked_malloc(s->n_out + outlen); + if (tmp == NULL) { + OSSL_FUNC(EVP_MD_CTX_free)(clone); + return; + } OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen); memcpy(output, tmp + s->n_out, outlen); - OQS_MEM_free(tmp); // IGNORE free-check + OQS_MEM_secure_free(tmp, s->n_out + outlen); } OSSL_FUNC(EVP_MD_CTX_free)(clone); s->n_out += outlen; @@ -210,8 +249,11 @@ static void SHA3_shake128_inc_squeeze(uint8_t *output, size_t outlen, OQS_SHA3_s static void SHA3_shake128_inc_ctx_release(OQS_SHA3_shake128_inc_ctx *state) { intrn_shake128_inc_ctx *s = (intrn_shake128_inc_ctx *)state->ctx; - OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx); - OQS_MEM_free(s); // IGNORE free-check + if (s != NULL) { + OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx); + OQS_MEM_secure_free(s, sizeof(intrn_shake128_inc_ctx)); + state->ctx = NULL; + } } static void SHA3_shake128_inc_ctx_clone(OQS_SHA3_shake128_inc_ctx *dest, const OQS_SHA3_shake128_inc_ctx *src) { diff --git a/src/common/sha3/ossl_sha3x4.c b/src/common/sha3/ossl_sha3x4.c index eb14a9f1fc..6692d729c1 100644 --- a/src/common/sha3/ossl_sha3x4.c +++ b/src/common/sha3/ossl_sha3x4.c @@ -33,13 +33,28 @@ typedef struct { } intrn_shake128_x4_inc_ctx; static void SHA3_shake128_x4_inc_init(OQS_SHA3_shake128_x4_inc_ctx *state) { + if (state == NULL) { + return; + } state->ctx = OQS_MEM_malloc(sizeof(intrn_shake128_x4_inc_ctx)); + if (state->ctx == NULL) { + return; + } intrn_shake128_x4_inc_ctx *s = (intrn_shake128_x4_inc_ctx *)state->ctx; s->mdctx0 = OSSL_FUNC(EVP_MD_CTX_new)(); s->mdctx1 = OSSL_FUNC(EVP_MD_CTX_new)(); s->mdctx2 = OSSL_FUNC(EVP_MD_CTX_new)(); s->mdctx3 = OSSL_FUNC(EVP_MD_CTX_new)(); + if (s->mdctx0 == NULL || s->mdctx1 == NULL || s->mdctx2 == NULL || s->mdctx3 == NULL) { + OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx0); + OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx1); + OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx2); + OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx3); + OQS_MEM_free(s); + state->ctx = NULL; + return; + } OSSL_FUNC(EVP_DigestInit_ex)(s->mdctx0, oqs_shake128(), NULL); OSSL_FUNC(EVP_DigestInit_ex)(s->mdctx1, oqs_shake128(), NULL); OSSL_FUNC(EVP_DigestInit_ex)(s->mdctx2, oqs_shake128(), NULL); @@ -48,6 +63,9 @@ static void SHA3_shake128_x4_inc_init(OQS_SHA3_shake128_x4_inc_ctx *state) { } static void SHA3_shake128_x4_inc_absorb(OQS_SHA3_shake128_x4_inc_ctx *state, const uint8_t *in0, const uint8_t *in1, const uint8_t *in2, const uint8_t *in3, size_t inplen) { + if (state == NULL || state->ctx == NULL) { + return; + } intrn_shake128_x4_inc_ctx *s = (intrn_shake128_x4_inc_ctx *)state->ctx; OSSL_FUNC(EVP_DigestUpdate)(s->mdctx0, in0, inplen); OSSL_FUNC(EVP_DigestUpdate)(s->mdctx1, in1, inplen); @@ -60,6 +78,9 @@ static void SHA3_shake128_x4_inc_finalize(OQS_SHA3_shake128_x4_inc_ctx *state) { } static void SHA3_shake128_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, size_t outlen, OQS_SHA3_shake128_x4_inc_ctx *state) { + if (state == NULL || state->ctx == NULL) { + return; + } intrn_shake128_x4_inc_ctx *s = (intrn_shake128_x4_inc_ctx *)state->ctx; #if OPENSSL_VERSION_NUMBER >= 0x30300000L EVP_DigestSqueeze(s->mdctx0, out0, outlen); @@ -70,6 +91,9 @@ static void SHA3_shake128_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t * EVP_MD_CTX *clone; clone = OSSL_FUNC(EVP_MD_CTX_new)(); + if (clone == NULL) { + return; + } OSSL_FUNC(EVP_DigestInit_ex)(clone, oqs_shake128(), NULL); if (s->n_out == 0) { OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx0); @@ -82,6 +106,10 @@ static void SHA3_shake128_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t * OSSL_FUNC(EVP_DigestFinalXOF)(clone, out3, outlen); } else { uint8_t *tmp = OQS_MEM_checked_malloc(s->n_out + outlen); + if (tmp == NULL) { + OSSL_FUNC(EVP_MD_CTX_free)(clone); + return; + } OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx0); OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen); memcpy(out0, tmp + s->n_out, outlen); @@ -94,7 +122,7 @@ static void SHA3_shake128_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t * OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx3); OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen); memcpy(out3, tmp + s->n_out, outlen); - OQS_MEM_free(tmp); // IGNORE free-check + OQS_MEM_free(tmp); } OSSL_FUNC(EVP_MD_CTX_free)(clone); s->n_out += outlen; @@ -102,8 +130,28 @@ static void SHA3_shake128_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t * } static void SHA3_shake128_x4_inc_ctx_clone(OQS_SHA3_shake128_x4_inc_ctx *dest, const OQS_SHA3_shake128_x4_inc_ctx *src) { + if (dest == NULL || src == NULL || src->ctx == NULL) { + return; + } + dest->ctx = OQS_MEM_malloc(sizeof(intrn_shake128_x4_inc_ctx)); + if (dest->ctx == NULL) { + return; + } intrn_shake128_x4_inc_ctx *s = (intrn_shake128_x4_inc_ctx *)src->ctx; intrn_shake128_x4_inc_ctx *d = (intrn_shake128_x4_inc_ctx *)dest->ctx; + d->mdctx0 = OSSL_FUNC(EVP_MD_CTX_new)(); + d->mdctx1 = OSSL_FUNC(EVP_MD_CTX_new)(); + d->mdctx2 = OSSL_FUNC(EVP_MD_CTX_new)(); + d->mdctx3 = OSSL_FUNC(EVP_MD_CTX_new)(); + if (d->mdctx0 == NULL || d->mdctx1 == NULL || d->mdctx2 == NULL || d->mdctx3 == NULL) { + OSSL_FUNC(EVP_MD_CTX_free)(d->mdctx0); + OSSL_FUNC(EVP_MD_CTX_free)(d->mdctx1); + OSSL_FUNC(EVP_MD_CTX_free)(d->mdctx2); + OSSL_FUNC(EVP_MD_CTX_free)(d->mdctx3); + OQS_MEM_free(d); + dest->ctx = NULL; + return; + } OSSL_FUNC(EVP_MD_CTX_copy_ex)(d->mdctx0, s->mdctx0); OSSL_FUNC(EVP_MD_CTX_copy_ex)(d->mdctx1, s->mdctx1); OSSL_FUNC(EVP_MD_CTX_copy_ex)(d->mdctx2, s->mdctx2); @@ -112,15 +160,22 @@ static void SHA3_shake128_x4_inc_ctx_clone(OQS_SHA3_shake128_x4_inc_ctx *dest, c } static void SHA3_shake128_x4_inc_ctx_release(OQS_SHA3_shake128_x4_inc_ctx *state) { + if (state == NULL || state->ctx == NULL) { + return; + } intrn_shake128_x4_inc_ctx *s = (intrn_shake128_x4_inc_ctx *)state->ctx; OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx0); OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx1); OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx2); OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx3); - OQS_MEM_free(s); // IGNORE free-check + OQS_MEM_free(s); + state->ctx = NULL; } static void SHA3_shake128_x4_inc_ctx_reset(OQS_SHA3_shake128_x4_inc_ctx *state) { + if (state == NULL || state->ctx == NULL) { + return; + } intrn_shake128_x4_inc_ctx *s = (intrn_shake128_x4_inc_ctx *)state->ctx; OSSL_FUNC(EVP_MD_CTX_reset)(s->mdctx0); OSSL_FUNC(EVP_MD_CTX_reset)(s->mdctx1); @@ -152,15 +207,29 @@ typedef struct { EVP_MD_CTX *mdctx3; size_t n_out; } intrn_shake256_x4_inc_ctx; - static void SHA3_shake256_x4_inc_init(OQS_SHA3_shake256_x4_inc_ctx *state) { + if (state == NULL) { + return; + } state->ctx = OQS_MEM_malloc(sizeof(intrn_shake256_x4_inc_ctx)); + if (state->ctx == NULL) { + return; + } intrn_shake256_x4_inc_ctx *s = (intrn_shake256_x4_inc_ctx *)state->ctx; s->mdctx0 = OSSL_FUNC(EVP_MD_CTX_new)(); s->mdctx1 = OSSL_FUNC(EVP_MD_CTX_new)(); s->mdctx2 = OSSL_FUNC(EVP_MD_CTX_new)(); s->mdctx3 = OSSL_FUNC(EVP_MD_CTX_new)(); + if (s->mdctx0 == NULL || s->mdctx1 == NULL || s->mdctx2 == NULL || s->mdctx3 == NULL) { + OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx0); + OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx1); + OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx2); + OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx3); + OQS_MEM_free(s); + state->ctx = NULL; + return; + } OSSL_FUNC(EVP_DigestInit_ex)(s->mdctx0, oqs_shake256(), NULL); OSSL_FUNC(EVP_DigestInit_ex)(s->mdctx1, oqs_shake256(), NULL); OSSL_FUNC(EVP_DigestInit_ex)(s->mdctx2, oqs_shake256(), NULL); @@ -169,6 +238,9 @@ static void SHA3_shake256_x4_inc_init(OQS_SHA3_shake256_x4_inc_ctx *state) { } static void SHA3_shake256_x4_inc_absorb(OQS_SHA3_shake256_x4_inc_ctx *state, const uint8_t *in0, const uint8_t *in1, const uint8_t *in2, const uint8_t *in3, size_t inplen) { + if (state == NULL || state->ctx == NULL) { + return; + } intrn_shake256_x4_inc_ctx *s = (intrn_shake256_x4_inc_ctx *)state->ctx; OSSL_FUNC(EVP_DigestUpdate)(s->mdctx0, in0, inplen); OSSL_FUNC(EVP_DigestUpdate)(s->mdctx1, in1, inplen); @@ -181,6 +253,9 @@ static void SHA3_shake256_x4_inc_finalize(OQS_SHA3_shake256_x4_inc_ctx *state) { } static void SHA3_shake256_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, size_t outlen, OQS_SHA3_shake256_x4_inc_ctx *state) { + if (state == NULL || state->ctx == NULL) { + return; + } intrn_shake256_x4_inc_ctx *s = (intrn_shake256_x4_inc_ctx *)state->ctx; #if OPENSSL_VERSION_NUMBER >= 0x30300000L EVP_DigestSqueeze(s->mdctx0, out0, outlen); @@ -191,6 +266,9 @@ static void SHA3_shake256_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t * EVP_MD_CTX *clone; clone = OSSL_FUNC(EVP_MD_CTX_new)(); + if (clone == NULL) { + return; + } OSSL_FUNC(EVP_DigestInit_ex)(clone, oqs_shake256(), NULL); if (s->n_out == 0) { OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx0); @@ -203,6 +281,10 @@ static void SHA3_shake256_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t * OSSL_FUNC(EVP_DigestFinalXOF)(clone, out3, outlen); } else { uint8_t *tmp = OQS_MEM_checked_malloc(s->n_out + outlen); + if (tmp == NULL) { + OSSL_FUNC(EVP_MD_CTX_free)(clone); + return; + } OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx0); OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen); memcpy(out0, tmp + s->n_out, outlen); @@ -215,7 +297,7 @@ static void SHA3_shake256_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t * OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx3); OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen); memcpy(out3, tmp + s->n_out, outlen); - OQS_MEM_free(tmp); // IGNORE free-check + OQS_MEM_free(tmp); } OSSL_FUNC(EVP_MD_CTX_free)(clone); s->n_out += outlen; @@ -223,6 +305,9 @@ static void SHA3_shake256_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t * } static void SHA3_shake256_x4_inc_ctx_clone(OQS_SHA3_shake256_x4_inc_ctx *dest, const OQS_SHA3_shake256_x4_inc_ctx *src) { + if (dest == NULL || src == NULL || dest->ctx == NULL || src->ctx == NULL) { + return; + } intrn_shake256_x4_inc_ctx *s = (intrn_shake256_x4_inc_ctx *)src->ctx; intrn_shake256_x4_inc_ctx *d = (intrn_shake256_x4_inc_ctx *)dest->ctx; OSSL_FUNC(EVP_MD_CTX_copy_ex)(d->mdctx0, s->mdctx0); @@ -233,15 +318,22 @@ static void SHA3_shake256_x4_inc_ctx_clone(OQS_SHA3_shake256_x4_inc_ctx *dest, c } static void SHA3_shake256_x4_inc_ctx_release(OQS_SHA3_shake256_x4_inc_ctx *state) { + if (state == NULL || state->ctx == NULL) { + return; + } intrn_shake256_x4_inc_ctx *s = (intrn_shake256_x4_inc_ctx *)state->ctx; OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx0); OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx1); OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx2); OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx3); - OQS_MEM_free(s); // IGNORE free-check + OQS_MEM_free(s); + state->ctx = NULL; } static void SHA3_shake256_x4_inc_ctx_reset(OQS_SHA3_shake256_x4_inc_ctx *state) { + if (state == NULL || state->ctx == NULL) { + return; + } intrn_shake256_x4_inc_ctx *s = (intrn_shake256_x4_inc_ctx *)state->ctx; OSSL_FUNC(EVP_MD_CTX_reset)(s->mdctx0); OSSL_FUNC(EVP_MD_CTX_reset)(s->mdctx1); diff --git a/src/common/sha3/sha3.c b/src/common/sha3/sha3.c index 600fc19839..2cc6eaed58 100644 --- a/src/common/sha3/sha3.c +++ b/src/common/sha3/sha3.c @@ -9,153 +9,229 @@ extern struct OQS_SHA3_callbacks sha3_default_callbacks; static struct OQS_SHA3_callbacks *callbacks = &sha3_default_callbacks; OQS_API void OQS_SHA3_set_callbacks(struct OQS_SHA3_callbacks *new_callbacks) { - callbacks = new_callbacks; + if (new_callbacks != NULL) { + callbacks = new_callbacks; + } } void OQS_SHA3_sha3_256(uint8_t *output, const uint8_t *input, size_t inplen) { - callbacks->SHA3_sha3_256(output, input, inplen); + if (callbacks != NULL && output != NULL && input != NULL) { + callbacks->SHA3_sha3_256(output, input, inplen); + } } void OQS_SHA3_sha3_256_inc_init(OQS_SHA3_sha3_256_inc_ctx *state) { - callbacks->SHA3_sha3_256_inc_init(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_sha3_256_inc_init(state); + } } void OQS_SHA3_sha3_256_inc_absorb(OQS_SHA3_sha3_256_inc_ctx *state, const uint8_t *input, size_t inlen) { - callbacks->SHA3_sha3_256_inc_absorb(state, input, inlen); + if (callbacks != NULL && state != NULL && input != NULL) { + callbacks->SHA3_sha3_256_inc_absorb(state, input, inlen); + } } void OQS_SHA3_sha3_256_inc_finalize(uint8_t *output, OQS_SHA3_sha3_256_inc_ctx *state) { - callbacks->SHA3_sha3_256_inc_finalize(output, state); + if (callbacks != NULL && output != NULL && state != NULL) { + callbacks->SHA3_sha3_256_inc_finalize(output, state); + } } void OQS_SHA3_sha3_256_inc_ctx_release(OQS_SHA3_sha3_256_inc_ctx *state) { - callbacks->SHA3_sha3_256_inc_ctx_release(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_sha3_256_inc_ctx_release(state); + } } void OQS_SHA3_sha3_256_inc_ctx_reset(OQS_SHA3_sha3_256_inc_ctx *state) { - callbacks->SHA3_sha3_256_inc_ctx_reset(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_sha3_256_inc_ctx_reset(state); + } } void OQS_SHA3_sha3_256_inc_ctx_clone(OQS_SHA3_sha3_256_inc_ctx *dest, const OQS_SHA3_sha3_256_inc_ctx *src) { - callbacks->SHA3_sha3_256_inc_ctx_clone(dest, src); + if (callbacks != NULL && dest != NULL && src != NULL) { + callbacks->SHA3_sha3_256_inc_ctx_clone(dest, src); + } } void OQS_SHA3_sha3_384(uint8_t *output, const uint8_t *input, size_t inplen) { - callbacks->SHA3_sha3_384(output, input, inplen); + if (callbacks != NULL && output != NULL && input != NULL) { + callbacks->SHA3_sha3_384(output, input, inplen); + } } void OQS_SHA3_sha3_384_inc_init(OQS_SHA3_sha3_384_inc_ctx *state) { - callbacks->SHA3_sha3_384_inc_init(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_sha3_384_inc_init(state); + } } void OQS_SHA3_sha3_384_inc_absorb(OQS_SHA3_sha3_384_inc_ctx *state, const uint8_t *input, size_t inlen) { - callbacks->SHA3_sha3_384_inc_absorb(state, input, inlen); + if (callbacks != NULL && state != NULL && input != NULL) { + callbacks->SHA3_sha3_384_inc_absorb(state, input, inlen); + } } void OQS_SHA3_sha3_384_inc_finalize(uint8_t *output, OQS_SHA3_sha3_384_inc_ctx *state) { - callbacks->SHA3_sha3_384_inc_finalize(output, state); + if (callbacks != NULL && output != NULL && state != NULL) { + callbacks->SHA3_sha3_384_inc_finalize(output, state); + } } void OQS_SHA3_sha3_384_inc_ctx_release(OQS_SHA3_sha3_384_inc_ctx *state) { - callbacks->SHA3_sha3_384_inc_ctx_release(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_sha3_384_inc_ctx_release(state); + } } void OQS_SHA3_sha3_384_inc_ctx_reset(OQS_SHA3_sha3_384_inc_ctx *state) { - callbacks->SHA3_sha3_384_inc_ctx_reset(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_sha3_384_inc_ctx_reset(state); + } } void OQS_SHA3_sha3_384_inc_ctx_clone(OQS_SHA3_sha3_384_inc_ctx *dest, const OQS_SHA3_sha3_384_inc_ctx *src) { - callbacks->SHA3_sha3_384_inc_ctx_clone(dest, src); + if (callbacks != NULL && dest != NULL && src != NULL) { + callbacks->SHA3_sha3_384_inc_ctx_clone(dest, src); + } } void OQS_SHA3_sha3_512(uint8_t *output, const uint8_t *input, size_t inplen) { - callbacks->SHA3_sha3_512(output, input, inplen); + if (callbacks != NULL && output != NULL && input != NULL) { + callbacks->SHA3_sha3_512(output, input, inplen); + } } void OQS_SHA3_sha3_512_inc_init(OQS_SHA3_sha3_512_inc_ctx *state) { - callbacks->SHA3_sha3_512_inc_init(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_sha3_512_inc_init(state); + } } void OQS_SHA3_sha3_512_inc_absorb(OQS_SHA3_sha3_512_inc_ctx *state, const uint8_t *input, size_t inlen) { - callbacks->SHA3_sha3_512_inc_absorb(state, input, inlen); + if (callbacks != NULL && state != NULL && input != NULL) { + callbacks->SHA3_sha3_512_inc_absorb(state, input, inlen); + } } void OQS_SHA3_sha3_512_inc_finalize(uint8_t *output, OQS_SHA3_sha3_512_inc_ctx *state) { - callbacks->SHA3_sha3_512_inc_finalize(output, state); + if (callbacks != NULL && output != NULL && state != NULL) { + callbacks->SHA3_sha3_512_inc_finalize(output, state); + } } void OQS_SHA3_sha3_512_inc_ctx_release(OQS_SHA3_sha3_512_inc_ctx *state) { - callbacks->SHA3_sha3_512_inc_ctx_release(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_sha3_512_inc_ctx_release(state); + } } void OQS_SHA3_sha3_512_inc_ctx_reset(OQS_SHA3_sha3_512_inc_ctx *state) { - callbacks->SHA3_sha3_512_inc_ctx_reset(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_sha3_512_inc_ctx_reset(state); + } } void OQS_SHA3_sha3_512_inc_ctx_clone(OQS_SHA3_sha3_512_inc_ctx *dest, const OQS_SHA3_sha3_512_inc_ctx *src) { - callbacks->SHA3_sha3_512_inc_ctx_clone(dest, src); + if (callbacks != NULL && dest != NULL && src != NULL) { + callbacks->SHA3_sha3_512_inc_ctx_clone(dest, src); + } } void OQS_SHA3_shake128(uint8_t *output, size_t outlen, const uint8_t *input, size_t inplen) { - callbacks->SHA3_shake128(output, outlen, input, inplen); + if (callbacks != NULL && output != NULL && input != NULL) { + callbacks->SHA3_shake128(output, outlen, input, inplen); + } } void OQS_SHA3_shake128_inc_init(OQS_SHA3_shake128_inc_ctx *state) { - callbacks->SHA3_shake128_inc_init(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake128_inc_init(state); + } } void OQS_SHA3_shake128_inc_absorb(OQS_SHA3_shake128_inc_ctx *state, const uint8_t *input, size_t inlen) { - callbacks->SHA3_shake128_inc_absorb(state, input, inlen); + if (callbacks != NULL && state != NULL && input != NULL) { + callbacks->SHA3_shake128_inc_absorb(state, input, inlen); + } } void OQS_SHA3_shake128_inc_finalize(OQS_SHA3_shake128_inc_ctx *state) { - callbacks->SHA3_shake128_inc_finalize(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake128_inc_finalize(state); + } } void OQS_SHA3_shake128_inc_squeeze(uint8_t *output, size_t outlen, OQS_SHA3_shake128_inc_ctx *state) { - callbacks->SHA3_shake128_inc_squeeze(output, outlen, state); + if (callbacks != NULL && output != NULL && state != NULL) { + callbacks->SHA3_shake128_inc_squeeze(output, outlen, state); + } } void OQS_SHA3_shake128_inc_ctx_release(OQS_SHA3_shake128_inc_ctx *state) { - callbacks->SHA3_shake128_inc_ctx_release(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake128_inc_ctx_release(state); + } } void OQS_SHA3_shake128_inc_ctx_clone(OQS_SHA3_shake128_inc_ctx *dest, const OQS_SHA3_shake128_inc_ctx *src) { - callbacks->SHA3_shake128_inc_ctx_clone(dest, src); + if (callbacks != NULL && dest != NULL && src != NULL) { + callbacks->SHA3_shake128_inc_ctx_clone(dest, src); + } } void OQS_SHA3_shake128_inc_ctx_reset(OQS_SHA3_shake128_inc_ctx *state) { - callbacks->SHA3_shake128_inc_ctx_reset(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake128_inc_ctx_reset(state); + } } void OQS_SHA3_shake256(uint8_t *output, size_t outlen, const uint8_t *input, size_t inplen) { - callbacks->SHA3_shake256(output, outlen, input, inplen); + if (callbacks != NULL && output != NULL && input != NULL) { + callbacks->SHA3_shake256(output, outlen, input, inplen); + } } void OQS_SHA3_shake256_inc_init(OQS_SHA3_shake256_inc_ctx *state) { - callbacks->SHA3_shake256_inc_init(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake256_inc_init(state); + } } void OQS_SHA3_shake256_inc_absorb(OQS_SHA3_shake256_inc_ctx *state, const uint8_t *input, size_t inlen) { - callbacks->SHA3_shake256_inc_absorb(state, input, inlen); + if (callbacks != NULL && state != NULL && input != NULL) { + callbacks->SHA3_shake256_inc_absorb(state, input, inlen); + } } void OQS_SHA3_shake256_inc_finalize(OQS_SHA3_shake256_inc_ctx *state) { - callbacks->SHA3_shake256_inc_finalize(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake256_inc_finalize(state); + } } void OQS_SHA3_shake256_inc_squeeze(uint8_t *output, size_t outlen, OQS_SHA3_shake256_inc_ctx *state) { - callbacks->SHA3_shake256_inc_squeeze(output, outlen, state); + if (callbacks != NULL && output != NULL && state != NULL) { + callbacks->SHA3_shake256_inc_squeeze(output, outlen, state); + } } void OQS_SHA3_shake256_inc_ctx_release(OQS_SHA3_shake256_inc_ctx *state) { - callbacks->SHA3_shake256_inc_ctx_release(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake256_inc_ctx_release(state); + } } void OQS_SHA3_shake256_inc_ctx_clone(OQS_SHA3_shake256_inc_ctx *dest, const OQS_SHA3_shake256_inc_ctx *src) { - callbacks->SHA3_shake256_inc_ctx_clone(dest, src); + if (callbacks != NULL && dest != NULL && src != NULL) { + callbacks->SHA3_shake256_inc_ctx_clone(dest, src); + } } void OQS_SHA3_shake256_inc_ctx_reset(OQS_SHA3_shake256_inc_ctx *state) { - callbacks->SHA3_shake256_inc_ctx_reset(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake256_inc_ctx_reset(state); + } } diff --git a/src/common/sha3/sha3x4.c b/src/common/sha3/sha3x4.c index b41ea08d3f..c526b7aa96 100644 --- a/src/common/sha3/sha3x4.c +++ b/src/common/sha3/sha3x4.c @@ -9,69 +9,103 @@ extern struct OQS_SHA3_x4_callbacks sha3_x4_default_callbacks; static struct OQS_SHA3_x4_callbacks *callbacks = &sha3_x4_default_callbacks; OQS_API void OQS_SHA3_x4_set_callbacks(struct OQS_SHA3_x4_callbacks *new_callbacks) { - callbacks = new_callbacks; + if (new_callbacks != NULL) { + callbacks = new_callbacks; + } } void OQS_SHA3_shake128_x4(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, size_t outlen, const uint8_t *in0, const uint8_t *in1, const uint8_t *in2, const uint8_t *in3, size_t inlen) { - callbacks->SHA3_shake128_x4(out0, out1, out2, out3, outlen, in0, in1, in2, in3, inlen); + if (callbacks != NULL) { + callbacks->SHA3_shake128_x4(out0, out1, out2, out3, outlen, in0, in1, in2, in3, inlen); + } } void OQS_SHA3_shake128_x4_inc_init(OQS_SHA3_shake128_x4_inc_ctx *state) { - callbacks->SHA3_shake128_x4_inc_init(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake128_x4_inc_init(state); + } } void OQS_SHA3_shake128_x4_inc_absorb(OQS_SHA3_shake128_x4_inc_ctx *state, const uint8_t *in0, const uint8_t *in1, const uint8_t *in2, const uint8_t *in3, size_t inlen) { - callbacks->SHA3_shake128_x4_inc_absorb(state, in0, in1, in2, in3, inlen); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake128_x4_inc_absorb(state, in0, in1, in2, in3, inlen); + } } void OQS_SHA3_shake128_x4_inc_finalize(OQS_SHA3_shake128_x4_inc_ctx *state) { - callbacks->SHA3_shake128_x4_inc_finalize(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake128_x4_inc_finalize(state); + } } void OQS_SHA3_shake128_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, size_t outlen, OQS_SHA3_shake128_x4_inc_ctx *state) { - callbacks->SHA3_shake128_x4_inc_squeeze(out0, out1, out2, out3, outlen, state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake128_x4_inc_squeeze(out0, out1, out2, out3, outlen, state); + } } void OQS_SHA3_shake128_x4_inc_ctx_release(OQS_SHA3_shake128_x4_inc_ctx *state) { - callbacks->SHA3_shake128_x4_inc_ctx_release(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake128_x4_inc_ctx_release(state); + } } void OQS_SHA3_shake128_x4_inc_ctx_clone(OQS_SHA3_shake128_x4_inc_ctx *dest, const OQS_SHA3_shake128_x4_inc_ctx *src) { - callbacks->SHA3_shake128_x4_inc_ctx_clone(dest, src); + if (callbacks != NULL && dest != NULL && src != NULL) { + callbacks->SHA3_shake128_x4_inc_ctx_clone(dest, src); + } } void OQS_SHA3_shake128_x4_inc_ctx_reset(OQS_SHA3_shake128_x4_inc_ctx *state) { - callbacks->SHA3_shake128_x4_inc_ctx_reset(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake128_x4_inc_ctx_reset(state); + } } void OQS_SHA3_shake256_x4(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, size_t outlen, const uint8_t *in0, const uint8_t *in1, const uint8_t *in2, const uint8_t *in3, size_t inlen) { - callbacks->SHA3_shake256_x4(out0, out1, out2, out3, outlen, in0, in1, in2, in3, inlen); + if (callbacks != NULL) { + callbacks->SHA3_shake256_x4(out0, out1, out2, out3, outlen, in0, in1, in2, in3, inlen); + } } void OQS_SHA3_shake256_x4_inc_init(OQS_SHA3_shake256_x4_inc_ctx *state) { - callbacks->SHA3_shake256_x4_inc_init(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake256_x4_inc_init(state); + } } void OQS_SHA3_shake256_x4_inc_absorb(OQS_SHA3_shake256_x4_inc_ctx *state, const uint8_t *in0, const uint8_t *in1, const uint8_t *in2, const uint8_t *in3, size_t inlen) { - callbacks->SHA3_shake256_x4_inc_absorb(state, in0, in1, in2, in3, inlen); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake256_x4_inc_absorb(state, in0, in1, in2, in3, inlen); + } } void OQS_SHA3_shake256_x4_inc_finalize(OQS_SHA3_shake256_x4_inc_ctx *state) { - callbacks->SHA3_shake256_x4_inc_finalize(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake256_x4_inc_finalize(state); + } } void OQS_SHA3_shake256_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, size_t outlen, OQS_SHA3_shake256_x4_inc_ctx *state) { - callbacks->SHA3_shake256_x4_inc_squeeze(out0, out1, out2, out3, outlen, state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake256_x4_inc_squeeze(out0, out1, out2, out3, outlen, state); + } } void OQS_SHA3_shake256_x4_inc_ctx_release(OQS_SHA3_shake256_x4_inc_ctx *state) { - callbacks->SHA3_shake256_x4_inc_ctx_release(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake256_x4_inc_ctx_release(state); + } } void OQS_SHA3_shake256_x4_inc_ctx_clone(OQS_SHA3_shake256_x4_inc_ctx *dest, const OQS_SHA3_shake256_x4_inc_ctx *src) { - callbacks->SHA3_shake256_x4_inc_ctx_clone(dest, src); + if (callbacks != NULL && dest != NULL && src != NULL) { + callbacks->SHA3_shake256_x4_inc_ctx_clone(dest, src); + } } void OQS_SHA3_shake256_x4_inc_ctx_reset(OQS_SHA3_shake256_x4_inc_ctx *state) { - callbacks->SHA3_shake256_x4_inc_ctx_reset(state); + if (callbacks != NULL && state != NULL) { + callbacks->SHA3_shake256_x4_inc_ctx_reset(state); + } } diff --git a/src/common/sha3/xkcp_low/KeccakP-1600/plain-64bits/KeccakP-1600-opt64.c b/src/common/sha3/xkcp_low/KeccakP-1600/plain-64bits/KeccakP-1600-opt64.c index d813a3679b..25f0479e15 100644 --- a/src/common/sha3/xkcp_low/KeccakP-1600/plain-64bits/KeccakP-1600-opt64.c +++ b/src/common/sha3/xkcp_low/KeccakP-1600/plain-64bits/KeccakP-1600-opt64.c @@ -83,6 +83,9 @@ void KeccakP1600_StaticInitialize(void) { } /* ---------------------------------------------------------------- */ void KeccakP1600_Initialize(void *state) { + if (state == NULL) { + return; + } memset(state, 0, 200); ((uint64_t *)state)[ 1] = ~(uint64_t)0; ((uint64_t *)state)[ 2] = ~(uint64_t)0; @@ -95,6 +98,9 @@ void KeccakP1600_Initialize(void *state) { /* ---------------------------------------------------------------- */ void KeccakP1600_AddBytesInLane(void *state, unsigned int lanePosition, const unsigned char *data, unsigned int offset, unsigned int length) { + if (state == NULL || data == NULL) { + return; + } #if (PLATFORM_BYTE_ORDER == IS_LITTLE_ENDIAN) uint64_t lane; if (length == 0) { @@ -120,6 +126,9 @@ void KeccakP1600_AddBytesInLane(void *state, unsigned int lanePosition, const un /* ---------------------------------------------------------------- */ void KeccakP1600_AddLanes(void *state, const unsigned char *data, unsigned int laneCount) { + if (state == NULL || data == NULL) { + return; + } #if (PLATFORM_BYTE_ORDER == IS_LITTLE_ENDIAN) unsigned int i = 0; /* If either pointer is misaligned, fall back to byte-wise xor. */ @@ -174,10 +183,16 @@ void KeccakP1600_AddLanes(void *state, const unsigned char *data, unsigned int l #if (PLATFORM_BYTE_ORDER == IS_LITTLE_ENDIAN) void KeccakP1600_AddByte(void *state, unsigned char byte, unsigned int offset) { + if (state == NULL) { + return; + } ((unsigned char *)state)[offset] ^= byte; } #else void KeccakP1600_AddByte(void *state, unsigned char byte, unsigned int offset) { + if (state == NULL) { + return; + } uint64_t lane = byte; lane <<= (offset % 8) * 8; ((uint64_t *)state)[offset / 8] ^= lane; @@ -187,12 +202,18 @@ void KeccakP1600_AddByte(void *state, unsigned char byte, unsigned int offset) { /* ---------------------------------------------------------------- */ void KeccakP1600_AddBytes(void *state, const unsigned char *data, unsigned int offset, unsigned int length) { + if (state == NULL || data == NULL) { + return; + } SnP_AddBytes(state, data, offset, length, KeccakP1600_AddLanes, KeccakP1600_AddBytesInLane, 8); } /* ---------------------------------------------------------------- */ void KeccakP1600_OverwriteBytesInLane(void *state, unsigned int lanePosition, const unsigned char *data, unsigned int offset, unsigned int length) { + if (state == NULL || data == NULL) { + return; + } #if (PLATFORM_BYTE_ORDER == IS_LITTLE_ENDIAN) if ((lanePosition == 1) || (lanePosition == 2) || (lanePosition == 8) || (lanePosition == 12) || (lanePosition == 17) || (lanePosition == 20)) { unsigned int i; @@ -216,15 +237,16 @@ void KeccakP1600_OverwriteBytesInLane(void *state, unsigned int lanePosition, co ((uint64_t *)state)[lanePosition] = lane; #endif } - /* ---------------------------------------------------------------- */ - void KeccakP1600_OverwriteLanes(void *state, const unsigned char *data, unsigned int laneCount) { + if (state == NULL || data == NULL) { + return; + } #if (PLATFORM_BYTE_ORDER == IS_LITTLE_ENDIAN) unsigned int lanePosition; for (lanePosition = 0; lanePosition < laneCount; lanePosition++) { - memcpy(((uint64_t *)state) + lanePosition, data, sizeof(uint64_t)); + memcpy(((uint64_t *)state) + lanePosition, data + (lanePosition * sizeof(uint64_t)), sizeof(uint64_t)); if ((lanePosition == 1) || (lanePosition == 2) || (lanePosition == 8) || (lanePosition == 12) || (lanePosition == 17) || (lanePosition == 20)) { ((uint64_t *)state)[lanePosition] = ~((uint64_t *)state)[lanePosition]; } @@ -253,12 +275,18 @@ void KeccakP1600_OverwriteLanes(void *state, const unsigned char *data, unsigned /* ---------------------------------------------------------------- */ void KeccakP1600_OverwriteBytes(void *state, const unsigned char *data, unsigned int offset, unsigned int length) { + if (state == NULL || data == NULL) { + return; + } SnP_OverwriteBytes(state, data, offset, length, KeccakP1600_OverwriteLanes, KeccakP1600_OverwriteBytesInLane, 8); } /* ---------------------------------------------------------------- */ void KeccakP1600_OverwriteWithZeroes(void *state, unsigned int byteCount) { + if (state == NULL) { + return; + } #if (PLATFORM_BYTE_ORDER == IS_LITTLE_ENDIAN) unsigned int lanePosition; @@ -304,6 +332,9 @@ void KeccakP1600_OverwriteWithZeroes(void *state, unsigned int byteCount) { /* ---------------------------------------------------------------- */ void KeccakP1600_Permute_Nrounds(void *state, unsigned int nr) { + if (state == NULL) { + return; + } declareABCDE unsigned int i; uint64_t *stateAsLanes = (uint64_t *)state; @@ -311,12 +342,14 @@ void KeccakP1600_Permute_Nrounds(void *state, unsigned int nr) { copyFromState(A, stateAsLanes) roundsN(nr) copyToState(stateAsLanes, A) - } /* ---------------------------------------------------------------- */ void KeccakP1600_Permute_24rounds(void *state) { + if (state == NULL) { + return; + } declareABCDE uint64_t *stateAsLanes = (uint64_t *)state; @@ -328,6 +361,9 @@ void KeccakP1600_Permute_24rounds(void *state) { /* ---------------------------------------------------------------- */ void KeccakP1600_Permute_12rounds(void *state) { + if (state == NULL) { + return; + } declareABCDE uint64_t *stateAsLanes = (uint64_t *)state; @@ -339,6 +375,9 @@ void KeccakP1600_Permute_12rounds(void *state) { /* ---------------------------------------------------------------- */ void KeccakP1600_ExtractBytesInLane(const void *state, unsigned int lanePosition, unsigned char *data, unsigned int offset, unsigned int length) { + if (state == NULL || data == NULL) { + return; + } uint64_t lane = ((const uint64_t *)state)[lanePosition]; if ((lanePosition == 1) || (lanePosition == 2) || (lanePosition == 8) || (lanePosition == 12) || (lanePosition == 17) || (lanePosition == 20)) { lane = ~lane; @@ -372,6 +411,9 @@ static void fromWordToBytes(uint8_t *bytes, const uint64_t word) { #endif void KeccakP1600_ExtractLanes(const void *state, unsigned char *data, unsigned int laneCount) { + if (state == NULL || data == NULL) { + return; + } #if (PLATFORM_BYTE_ORDER == IS_LITTLE_ENDIAN) memcpy(data, state, laneCount * 8); #else @@ -404,12 +446,18 @@ void KeccakP1600_ExtractLanes(const void *state, unsigned char *data, unsigned i /* ---------------------------------------------------------------- */ void KeccakP1600_ExtractBytes(const void *state, unsigned char *data, unsigned int offset, unsigned int length) { + if (state == NULL || data == NULL) { + return; + } SnP_ExtractBytes(state, data, offset, length, KeccakP1600_ExtractLanes, KeccakP1600_ExtractBytesInLane, 8); } /* ---------------------------------------------------------------- */ void KeccakP1600_ExtractAndAddBytesInLane(const void *state, unsigned int lanePosition, const unsigned char *input, unsigned char *output, unsigned int offset, unsigned int length) { + if (state == NULL || input == NULL || output == NULL) { + return; + } uint64_t lane = ((const uint64_t *)state)[lanePosition]; if ((lanePosition == 1) || (lanePosition == 2) || (lanePosition == 8) || (lanePosition == 12) || (lanePosition == 17) || (lanePosition == 20)) { lane = ~lane; @@ -432,8 +480,11 @@ void KeccakP1600_ExtractAndAddBytesInLane(const void *state, unsigned int lanePo } /* ---------------------------------------------------------------- */ - void KeccakP1600_ExtractAndAddLanes(const void *state, const unsigned char *input, unsigned char *output, unsigned int laneCount) { + if (state == NULL || input == NULL || output == NULL) { + return; + } + unsigned int i; #if (PLATFORM_BYTE_ORDER != IS_LITTLE_ENDIAN) unsigned char temp[8]; @@ -479,12 +530,18 @@ void KeccakP1600_ExtractAndAddLanes(const void *state, const unsigned char *inpu /* ---------------------------------------------------------------- */ void KeccakP1600_ExtractAndAddBytes(const void *state, const unsigned char *input, unsigned char *output, unsigned int offset, unsigned int length) { + if (state == NULL || input == NULL || output == NULL) { + return; + } SnP_ExtractAndAddBytes(state, input, output, offset, length, KeccakP1600_ExtractAndAddLanes, KeccakP1600_ExtractAndAddBytesInLane, 8); } /* ---------------------------------------------------------------- */ size_t KeccakF1600_FastLoop_Absorb(void *state, unsigned int laneCount, const unsigned char *data, size_t dataByteLen) { + if (state == NULL || data == NULL) { + return 0; + } size_t originalDataByteLen = dataByteLen; declareABCDE uint64_t *stateAsLanes = (uint64_t *)state; @@ -503,6 +560,9 @@ size_t KeccakF1600_FastLoop_Absorb(void *state, unsigned int laneCount, const un /* ---------------------------------------------------------------- */ size_t KeccakP1600_12rounds_FastLoop_Absorb(void *state, unsigned int laneCount, const unsigned char *data, size_t dataByteLen) { + if (state == NULL || data == NULL) { + return 0; + } size_t originalDataByteLen = dataByteLen; declareABCDE uint64_t *stateAsLanes = (uint64_t *)state; diff --git a/src/common/sha3/xkcp_low/KeccakP-1600times4/avx2/KeccakP-1600-times4-SIMD256.c b/src/common/sha3/xkcp_low/KeccakP-1600times4/avx2/KeccakP-1600-times4-SIMD256.c index 169fe97716..6205d545d3 100644 --- a/src/common/sha3/xkcp_low/KeccakP-1600times4/avx2/KeccakP-1600-times4-SIMD256.c +++ b/src/common/sha3/xkcp_low/KeccakP-1600times4/avx2/KeccakP-1600-times4-SIMD256.c @@ -91,6 +91,9 @@ static const union { #define SnP_laneLengthInBytes 8 static inline uint64_t load64(const unsigned char *x) { + if (x == NULL) { + return 0; + } return (uint64_t) x[0] \ | (uint64_t) x[1] << 0x08 \ | (uint64_t) x[2] << 0x10 \ @@ -102,18 +105,28 @@ static inline uint64_t load64(const unsigned char *x) { } static void store64(unsigned char *out, uint64_t in) { - memcpy(out, &in, sizeof(uint64_t)); + if (out != NULL) { + memcpy(out, &in, sizeof(uint64_t)); + } } void KeccakP1600times4_InitializeAll(void *states) { - memset(states, 0, KeccakP1600times4_statesSizeInBytes_avx2); + if (states != NULL) { + memset(states, 0, KeccakP1600times4_statesSizeInBytes_avx2); + } } void KeccakP1600times4_AddByte(void *states, unsigned int instanceIndex, unsigned char byte, unsigned int offset) { - ((unsigned char *)states)[instanceIndex * 8 + (offset / 8) * 4 * 8 + offset % 8] ^= byte; + if (states != NULL) { + ((unsigned char *)states)[instanceIndex * 8 + (offset / 8) * 4 * 8 + offset % 8] ^= byte; + } } void KeccakP1600times4_AddBytes(void *states, unsigned int instanceIndex, const unsigned char *data, unsigned int offset, unsigned int length) { + if (states == NULL || data == NULL) { + return; + } + unsigned int sizeLeft = length; unsigned int lanePosition = offset / SnP_laneLengthInBytes; unsigned int offsetInLane = offset % SnP_laneLengthInBytes; @@ -151,6 +164,10 @@ void KeccakP1600times4_AddBytes(void *states, unsigned int instanceIndex, const } void KeccakP1600times4_AddLanesAll(void *states, const unsigned char *data, unsigned int laneCount, unsigned int laneOffset) { + if (states == NULL || data == NULL) { + return; + } + V256 *stateAsLanes = (V256 *)states; unsigned int i; const unsigned char *curData0 = data; @@ -199,6 +216,10 @@ void KeccakP1600times4_AddLanesAll(void *states, const unsigned char *data, unsi } void KeccakP1600times4_OverwriteBytes(void *states, unsigned int instanceIndex, const unsigned char *data, unsigned int offset, unsigned int length) { + if (states == NULL || data == NULL) { + return; + } + unsigned int sizeLeft = length; unsigned int lanePosition = offset / SnP_laneLengthInBytes; unsigned int offsetInLane = offset % SnP_laneLengthInBytes; @@ -228,8 +249,11 @@ void KeccakP1600times4_OverwriteBytes(void *states, unsigned int instanceIndex, memcpy(&statesAsLanes[laneIndex(instanceIndex, lanePosition)], curData, sizeLeft); } } - void KeccakP1600times4_OverwriteLanesAll(void *states, const unsigned char *data, unsigned int laneCount, unsigned int laneOffset) { + if (states == NULL || data == NULL) { + return; + } + V256 *stateAsLanes = (V256 *)states; unsigned int i; const unsigned char *curData0 = data; @@ -278,6 +302,10 @@ void KeccakP1600times4_OverwriteLanesAll(void *states, const unsigned char *data } void KeccakP1600times4_OverwriteWithZeroes(void *states, unsigned int instanceIndex, unsigned int byteCount) { + if (states == NULL) { + return; + } + unsigned int sizeLeft = byteCount; unsigned int lanePosition = 0; uint64_t *statesAsLanes = (uint64_t *)states; @@ -294,6 +322,10 @@ void KeccakP1600times4_OverwriteWithZeroes(void *states, unsigned int instanceIn } void KeccakP1600times4_ExtractBytes(const void *states, unsigned int instanceIndex, unsigned char *data, unsigned int offset, unsigned int length) { + if (states == NULL || data == NULL) { + return; + } + unsigned int sizeLeft = length; unsigned int lanePosition = offset / SnP_laneLengthInBytes; unsigned int offsetInLane = offset % SnP_laneLengthInBytes; @@ -324,6 +356,10 @@ void KeccakP1600times4_ExtractBytes(const void *states, unsigned int instanceInd } void KeccakP1600times4_ExtractLanesAll(const void *states, unsigned char *data, unsigned int laneCount, unsigned int laneOffset) { + if (states == NULL || data == NULL) { + return; + } + unsigned char *curData0 = data; unsigned char *curData1 = data + laneOffset * 1 * SnP_laneLengthInBytes; unsigned char *curData2 = data + laneOffset * 2 * SnP_laneLengthInBytes; @@ -374,6 +410,10 @@ void KeccakP1600times4_ExtractLanesAll(const void *states, unsigned char *data, } void KeccakP1600times4_ExtractAndAddBytes(const void *states, unsigned int instanceIndex, const unsigned char *input, unsigned char *output, unsigned int offset, unsigned int length) { + if (states == NULL || input == NULL || output == NULL) { + return; + } + unsigned int sizeLeft = length; unsigned int lanePosition = offset / SnP_laneLengthInBytes; unsigned int offsetInLane = offset % SnP_laneLengthInBytes; @@ -413,6 +453,10 @@ void KeccakP1600times4_ExtractAndAddBytes(const void *states, unsigned int insta } void KeccakP1600times4_ExtractAndAddLanesAll(const void *states, const unsigned char *input, unsigned char *output, unsigned int laneCount, unsigned int laneOffset) { + if (states == NULL || input == NULL || output == NULL) { + return; + } + const unsigned char *curInput0 = input; const unsigned char *curInput1 = input + laneOffset * 1 * SnP_laneLengthInBytes; const unsigned char *curInput2 = input + laneOffset * 2 * SnP_laneLengthInBytes; @@ -474,7 +518,6 @@ void KeccakP1600times4_ExtractAndAddLanesAll(const void *states, const unsigned #undef ExtrXor #undef ExtrXor4 } - #define declareABCDE \ V256 Aba, Abe, Abi, Abo, Abu; \ V256 Aga, Age, Agi, Ago, Agu; \ @@ -829,8 +872,10 @@ static ALIGN(KeccakP1600times4_statesAlignment_avx2) const uint64_t KeccakF1600R #define FullUnrolling #include "KeccakP-1600-unrolling.macros" - void KeccakP1600times4_PermuteAll_24rounds(void *states) { + if (states == NULL) { + return; + } V256 *statesAsLanes = (V256 *)states; declareABCDE @@ -840,6 +885,9 @@ void KeccakP1600times4_PermuteAll_24rounds(void *states) { } void KeccakP1600times4_PermuteAll_12rounds(void *states) { + if (states == NULL) { + return; + } V256 *statesAsLanes = (V256 *)states; declareABCDE @@ -849,6 +897,9 @@ void KeccakP1600times4_PermuteAll_12rounds(void *states) { } void KeccakP1600times4_PermuteAll_6rounds(void *states) { + if (states == NULL) { + return; + } V256 *statesAsLanes = (V256 *)states; declareABCDE @@ -858,6 +909,9 @@ void KeccakP1600times4_PermuteAll_6rounds(void *states) { } void KeccakP1600times4_PermuteAll_4rounds(void *states) { + if (states == NULL) { + return; + } V256 *statesAsLanes = (V256 *)states; declareABCDE @@ -867,6 +921,9 @@ void KeccakP1600times4_PermuteAll_4rounds(void *states) { } size_t KeccakF1600times4_FastLoop_Absorb(void *states, unsigned int laneCount, unsigned int laneOffsetParallel, unsigned int laneOffsetSerial, const unsigned char *data, size_t dataByteLen) { + if (states == NULL || data == NULL) { + return 0; + } if (laneCount == 21) { const unsigned char *curData0 = data; const unsigned char *curData1 = data + laneOffsetParallel * 1 * SnP_laneLengthInBytes; @@ -924,6 +981,9 @@ size_t KeccakF1600times4_FastLoop_Absorb(void *states, unsigned int laneCount, u } size_t KeccakP1600times4_12rounds_FastLoop_Absorb(void *states, unsigned int laneCount, unsigned int laneOffsetParallel, unsigned int laneOffsetSerial, const unsigned char *data, size_t dataByteLen) { + if (states == NULL || data == NULL) { + return 0; + } if (laneCount == 21) { const unsigned char *curData0 = data; const unsigned char *curData1 = data + laneOffsetParallel * 1 * SnP_laneLengthInBytes; diff --git a/src/common/sha3/xkcp_low/KeccakP-1600times4/serial/PlSnP-Fallback.inc b/src/common/sha3/xkcp_low/KeccakP-1600times4/serial/PlSnP-Fallback.inc index 7006c126bc..cb0e05f260 100644 --- a/src/common/sha3/xkcp_low/KeccakP-1600times4/serial/PlSnP-Fallback.inc +++ b/src/common/sha3/xkcp_low/KeccakP-1600times4/serial/PlSnP-Fallback.inc @@ -70,7 +70,6 @@ Please refer to PlSnP-documentation.h for more details. #define SnP_ExtractLanesAll JOIN(SnP_prefix, ExtractLanesAll, SnP_suffix) #define SnP_ExtractAndAddBytes JOIN(SnP_prefix, ExtractAndAddBytes, SnP_suffix) #define SnP_ExtractAndAddLanesAll JOIN(SnP_prefix, ExtractAndAddLanesAll, SnP_suffix) - void PlSnP_StaticInitialize( void ) { SnP_StaticInitialize(); @@ -80,6 +79,10 @@ void PlSnP_InitializeAll(void *states) { unsigned int i; + if (states == NULL) { + return; + } + for(i=0; i= c) { @@ -142,41 +148,47 @@ static void keccak_inc_absorb(uint64_t *s, uint32_t r, const uint8_t *m, } /************************************************* - * Name: keccak_inc_finalize - * - * Description: Finalizes Keccak absorb phase, prepares for squeezing - * - * Arguments: - uint64_t *s: pointer to input/output incremental state - * First 25 values represent Keccak state. - * 26th value represents either the number of absorbed bytes - * that have not been permuted, or not-yet-squeezed bytes. - * - uint32_t r: rate in bytes (e.g., 168 for SHAKE128) - * - uint8_t p: domain-separation byte for different - * Keccak-derived functions - **************************************************/ + * Name: keccak_inc_finalize + * + * Description: Finalizes Keccak absorb phase, prepares for squeezing + * + * Arguments: - uint64_t *s: pointer to input/output incremental state + * First 25 values represent Keccak state. + * 26th value represents either the number of absorbed bytes + * that have not been permuted, or not-yet-squeezed bytes. + * - uint32_t r: rate in bytes (e.g., 168 for SHAKE128) + * - uint8_t p: domain-separation byte for different + * Keccak-derived functions + **************************************************/ static void keccak_inc_finalize(uint64_t *s, uint32_t r, uint8_t p) { + if (s == NULL) { + return; + } /* After keccak_inc_absorb, we are guaranteed that s[25] < r, - so we can always use one more byte for p in the current state. */ + so we can always use one more byte for p in the current state. */ (*Keccak_AddByte_ptr)(s, p, (unsigned int)s[25]); (*Keccak_AddByte_ptr)(s, 0x80, (unsigned int)(r - 1)); s[25] = 0; } /************************************************* - * Name: keccak_inc_squeeze - * - * Description: Incremental Keccak squeeze; can be called on byte-level - * - * Arguments: - uint8_t *h: pointer to output bytes - * - size_t outlen: number of bytes to be squeezed - * - uint64_t *s: pointer to input/output incremental state - * First 25 values represent Keccak state. - * 26th value represents either the number of absorbed bytes - * that have not been permuted, or not-yet-squeezed bytes. - * - uint32_t r: rate in bytes (e.g., 168 for SHAKE128) - **************************************************/ + * Name: keccak_inc_squeeze + * + * Description: Incremental Keccak squeeze; can be called on byte-level + * + * Arguments: - uint8_t *h: pointer to output bytes + * - size_t outlen: number of bytes to be squeezed + * - uint64_t *s: pointer to input/output incremental state + * First 25 values represent Keccak state. + * 26th value represents either the number of absorbed bytes + * that have not been permuted, or not-yet-squeezed bytes. + * - uint32_t r: rate in bytes (e.g., 168 for SHAKE128) + **************************************************/ static void keccak_inc_squeeze(uint8_t *h, size_t outlen, uint64_t *s, uint32_t r) { + if (h == NULL || s == NULL) { + return; + } while (outlen > s[25]) { (*Keccak_ExtractBytes_ptr)(s, h, (unsigned int)(r - s[25]), (unsigned int)s[25]); (*Keccak_Permute_ptr)(s); @@ -193,190 +205,342 @@ static void keccak_inc_squeeze(uint8_t *h, size_t outlen, static void SHA3_sha3_256(uint8_t *output, const uint8_t *input, size_t inlen) { OQS_SHA3_sha3_256_inc_ctx s; OQS_SHA3_sha3_256_inc_init(&s); - OQS_SHA3_sha3_256_inc_absorb(&s, input, inlen); - OQS_SHA3_sha3_256_inc_finalize(output, &s); - OQS_SHA3_sha3_256_inc_ctx_release(&s); + if (s.ctx != NULL) { + OQS_SHA3_sha3_256_inc_absorb(&s, input, inlen); + OQS_SHA3_sha3_256_inc_finalize(output, &s); + OQS_SHA3_sha3_256_inc_ctx_release(&s); + } } static void SHA3_sha3_256_inc_init(OQS_SHA3_sha3_256_inc_ctx *state) { + if (state == NULL) { + return; + } state->ctx = OQS_MEM_checked_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES); - keccak_inc_reset((uint64_t *)state->ctx); + if (state->ctx != NULL) { + keccak_inc_reset((uint64_t *)state->ctx); + } } static void SHA3_sha3_256_inc_absorb(OQS_SHA3_sha3_256_inc_ctx *state, const uint8_t *input, size_t inlen) { + if (state == NULL || state->ctx == NULL || input == NULL) { + return; + } keccak_inc_absorb((uint64_t *)state->ctx, OQS_SHA3_SHA3_256_RATE, input, inlen); } static void SHA3_sha3_256_inc_finalize(uint8_t *output, OQS_SHA3_sha3_256_inc_ctx *state) { + if (output == NULL || state == NULL || state->ctx == NULL) { + return; + } keccak_inc_finalize((uint64_t *)state->ctx, OQS_SHA3_SHA3_256_RATE, 0x06); keccak_inc_squeeze(output, 32, (uint64_t *)state->ctx, OQS_SHA3_SHA3_256_RATE); } static void SHA3_sha3_256_inc_ctx_release(OQS_SHA3_sha3_256_inc_ctx *state) { - OQS_MEM_aligned_free(state->ctx); + if (state != NULL && state->ctx != NULL) { + OQS_MEM_aligned_free(state->ctx); + state->ctx = NULL; + } } static void SHA3_sha3_256_inc_ctx_clone(OQS_SHA3_sha3_256_inc_ctx *dest, const OQS_SHA3_sha3_256_inc_ctx *src) { - memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES); + if (dest == NULL || src == NULL || src->ctx == NULL) { + return; + } + if (dest->ctx == NULL) { + dest->ctx = OQS_MEM_checked_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES); + } + if (dest->ctx != NULL) { + memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES); + } } static void SHA3_sha3_256_inc_ctx_reset(OQS_SHA3_sha3_256_inc_ctx *state) { + if (state == NULL || state->ctx == NULL) { + return; + } keccak_inc_reset((uint64_t *)state->ctx); } /* SHA3-384 */ - static void SHA3_sha3_384(uint8_t *output, const uint8_t *input, size_t inlen) { + if (output == NULL || input == NULL) { + return; + } OQS_SHA3_sha3_384_inc_ctx s; OQS_SHA3_sha3_384_inc_init(&s); - OQS_SHA3_sha3_384_inc_absorb(&s, input, inlen); - OQS_SHA3_sha3_384_inc_finalize(output, &s); - OQS_SHA3_sha3_384_inc_ctx_release(&s); + if (s.ctx != NULL) { + OQS_SHA3_sha3_384_inc_absorb(&s, input, inlen); + OQS_SHA3_sha3_384_inc_finalize(output, &s); + OQS_SHA3_sha3_384_inc_ctx_release(&s); + } } static void SHA3_sha3_384_inc_init(OQS_SHA3_sha3_384_inc_ctx *state) { + if (state == NULL) { + return; + } state->ctx = OQS_MEM_checked_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES); - keccak_inc_reset((uint64_t *)state->ctx); + if (state->ctx != NULL) { + keccak_inc_reset((uint64_t *)state->ctx); + } } static void SHA3_sha3_384_inc_absorb(OQS_SHA3_sha3_384_inc_ctx *state, const uint8_t *input, size_t inlen) { + if (state == NULL || state->ctx == NULL || input == NULL) { + return; + } keccak_inc_absorb((uint64_t *)state->ctx, OQS_SHA3_SHA3_384_RATE, input, inlen); } static void SHA3_sha3_384_inc_finalize(uint8_t *output, OQS_SHA3_sha3_384_inc_ctx *state) { + if (output == NULL || state == NULL || state->ctx == NULL) { + return; + } keccak_inc_finalize((uint64_t *)state->ctx, OQS_SHA3_SHA3_384_RATE, 0x06); keccak_inc_squeeze(output, 48, (uint64_t *)state->ctx, OQS_SHA3_SHA3_384_RATE); } static void SHA3_sha3_384_inc_ctx_release(OQS_SHA3_sha3_384_inc_ctx *state) { - OQS_MEM_aligned_free(state->ctx); + if (state != NULL && state->ctx != NULL) { + OQS_MEM_aligned_free(state->ctx); + state->ctx = NULL; + } } static void SHA3_sha3_384_inc_ctx_clone(OQS_SHA3_sha3_384_inc_ctx *dest, const OQS_SHA3_sha3_384_inc_ctx *src) { - memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES); + if (dest == NULL || src == NULL || src->ctx == NULL) { + return; + } + if (dest->ctx == NULL) { + dest->ctx = OQS_MEM_checked_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES); + } + if (dest->ctx != NULL) { + memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES); + } } static void SHA3_sha3_384_inc_ctx_reset(OQS_SHA3_sha3_384_inc_ctx *state) { + if (state == NULL || state->ctx == NULL) { + return; + } keccak_inc_reset((uint64_t *)state->ctx); } /* SHA3-512 */ static void SHA3_sha3_512(uint8_t *output, const uint8_t *input, size_t inlen) { + if (output == NULL || input == NULL) { + return; + } OQS_SHA3_sha3_512_inc_ctx s; OQS_SHA3_sha3_512_inc_init(&s); - OQS_SHA3_sha3_512_inc_absorb(&s, input, inlen); - OQS_SHA3_sha3_512_inc_finalize(output, &s); - OQS_SHA3_sha3_512_inc_ctx_release(&s); + if (s.ctx != NULL) { + OQS_SHA3_sha3_512_inc_absorb(&s, input, inlen); + OQS_SHA3_sha3_512_inc_finalize(output, &s); + OQS_SHA3_sha3_512_inc_ctx_release(&s); + } } static void SHA3_sha3_512_inc_init(OQS_SHA3_sha3_512_inc_ctx *state) { + if (state == NULL) { + return; + } state->ctx = OQS_MEM_checked_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES); - keccak_inc_reset((uint64_t *)state->ctx); + if (state->ctx != NULL) { + keccak_inc_reset((uint64_t *)state->ctx); + } } static void SHA3_sha3_512_inc_absorb(OQS_SHA3_sha3_512_inc_ctx *state, const uint8_t *input, size_t inlen) { + if (state == NULL || state->ctx == NULL || input == NULL) { + return; + } keccak_inc_absorb((uint64_t *)state->ctx, OQS_SHA3_SHA3_512_RATE, input, inlen); } static void SHA3_sha3_512_inc_finalize(uint8_t *output, OQS_SHA3_sha3_512_inc_ctx *state) { + if (output == NULL || state == NULL || state->ctx == NULL) { + return; + } keccak_inc_finalize((uint64_t *)state->ctx, OQS_SHA3_SHA3_512_RATE, 0x06); keccak_inc_squeeze(output, 64, (uint64_t *)state->ctx, OQS_SHA3_SHA3_512_RATE); } static void SHA3_sha3_512_inc_ctx_release(OQS_SHA3_sha3_512_inc_ctx *state) { - OQS_MEM_aligned_free(state->ctx); + if (state != NULL && state->ctx != NULL) { + OQS_MEM_aligned_free(state->ctx); + state->ctx = NULL; + } } static void SHA3_sha3_512_inc_ctx_clone(OQS_SHA3_sha3_512_inc_ctx *dest, const OQS_SHA3_sha3_512_inc_ctx *src) { - memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES); + if (dest == NULL || src == NULL || src->ctx == NULL) { + return; + } + if (dest->ctx == NULL) { + dest->ctx = OQS_MEM_checked_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES); + } + if (dest->ctx != NULL) { + memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES); + } } static void SHA3_sha3_512_inc_ctx_reset(OQS_SHA3_sha3_512_inc_ctx *state) { + if (state == NULL || state->ctx == NULL) { + return; + } keccak_inc_reset((uint64_t *)state->ctx); } /* SHAKE128 */ static void SHA3_shake128(uint8_t *output, size_t outlen, const uint8_t *input, size_t inlen) { + if (output == NULL || input == NULL) { + return; + } OQS_SHA3_shake128_inc_ctx s; OQS_SHA3_shake128_inc_init(&s); - OQS_SHA3_shake128_inc_absorb(&s, input, inlen); - OQS_SHA3_shake128_inc_finalize(&s); - OQS_SHA3_shake128_inc_squeeze(output, outlen, &s); - OQS_SHA3_shake128_inc_ctx_release(&s); + if (s.ctx != NULL) { + OQS_SHA3_shake128_inc_absorb(&s, input, inlen); + OQS_SHA3_shake128_inc_finalize(&s); + OQS_SHA3_shake128_inc_squeeze(output, outlen, &s); + OQS_SHA3_shake128_inc_ctx_release(&s); + } } /* SHAKE128 incremental */ static void SHA3_shake128_inc_init(OQS_SHA3_shake128_inc_ctx *state) { + if (state == NULL) { + return; + } state->ctx = OQS_MEM_checked_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES); - keccak_inc_reset((uint64_t *)state->ctx); + if (state->ctx != NULL) { + keccak_inc_reset((uint64_t *)state->ctx); + } } static void SHA3_shake128_inc_absorb(OQS_SHA3_shake128_inc_ctx *state, const uint8_t *input, size_t inlen) { + if (state == NULL || state->ctx == NULL || input == NULL) { + return; + } keccak_inc_absorb((uint64_t *)state->ctx, OQS_SHA3_SHAKE128_RATE, input, inlen); } static void SHA3_shake128_inc_finalize(OQS_SHA3_shake128_inc_ctx *state) { + if (state == NULL || state->ctx == NULL) { + return; + } keccak_inc_finalize((uint64_t *)state->ctx, OQS_SHA3_SHAKE128_RATE, 0x1F); } static void SHA3_shake128_inc_squeeze(uint8_t *output, size_t outlen, OQS_SHA3_shake128_inc_ctx *state) { + if (output == NULL || state == NULL || state->ctx == NULL) { + return; + } keccak_inc_squeeze(output, outlen, (uint64_t *)state->ctx, OQS_SHA3_SHAKE128_RATE); } static void SHA3_shake128_inc_ctx_clone(OQS_SHA3_shake128_inc_ctx *dest, const OQS_SHA3_shake128_inc_ctx *src) { - memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES); + if (dest == NULL || src == NULL || src->ctx == NULL) { + return; + } + if (dest->ctx == NULL) { + dest->ctx = OQS_MEM_checked_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES); + } + if (dest->ctx != NULL) { + memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES); + } } static void SHA3_shake128_inc_ctx_release(OQS_SHA3_shake128_inc_ctx *state) { - OQS_MEM_aligned_free(state->ctx); + if (state != NULL && state->ctx != NULL) { + OQS_MEM_aligned_free(state->ctx); + state->ctx = NULL; + } } static void SHA3_shake128_inc_ctx_reset(OQS_SHA3_shake128_inc_ctx *state) { + if (state == NULL || state->ctx == NULL) { + return; + } keccak_inc_reset((uint64_t *)state->ctx); } /* SHAKE256 */ static void SHA3_shake256(uint8_t *output, size_t outlen, const uint8_t *input, size_t inlen) { + if (output == NULL || input == NULL) { + return; + } OQS_SHA3_shake256_inc_ctx s; OQS_SHA3_shake256_inc_init(&s); - OQS_SHA3_shake256_inc_absorb(&s, input, inlen); - OQS_SHA3_shake256_inc_finalize(&s); - OQS_SHA3_shake256_inc_squeeze(output, outlen, &s); - OQS_SHA3_shake256_inc_ctx_release(&s); + if (s.ctx != NULL) { + OQS_SHA3_shake256_inc_absorb(&s, input, inlen); + OQS_SHA3_shake256_inc_finalize(&s); + OQS_SHA3_shake256_inc_squeeze(output, outlen, &s); + OQS_SHA3_shake256_inc_ctx_release(&s); + } } /* SHAKE256 incremental */ static void SHA3_shake256_inc_init(OQS_SHA3_shake256_inc_ctx *state) { + if (state == NULL) { + return; + } state->ctx = OQS_MEM_checked_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES); - keccak_inc_reset((uint64_t *)state->ctx); + if (state->ctx != NULL) { + keccak_inc_reset((uint64_t *)state->ctx); + } } static void SHA3_shake256_inc_absorb(OQS_SHA3_shake256_inc_ctx *state, const uint8_t *input, size_t inlen) { + if (state == NULL || state->ctx == NULL || input == NULL) { + return; + } keccak_inc_absorb((uint64_t *)state->ctx, OQS_SHA3_SHAKE256_RATE, input, inlen); } static void SHA3_shake256_inc_finalize(OQS_SHA3_shake256_inc_ctx *state) { + if (state == NULL || state->ctx == NULL) { + return; + } keccak_inc_finalize((uint64_t *)state->ctx, OQS_SHA3_SHAKE256_RATE, 0x1F); } static void SHA3_shake256_inc_squeeze(uint8_t *output, size_t outlen, OQS_SHA3_shake256_inc_ctx *state) { + if (output == NULL || state == NULL || state->ctx == NULL) { + return; + } keccak_inc_squeeze(output, outlen, state->ctx, OQS_SHA3_SHAKE256_RATE); } static void SHA3_shake256_inc_ctx_release(OQS_SHA3_shake256_inc_ctx *state) { - OQS_MEM_aligned_free(state->ctx); + if (state != NULL && state->ctx != NULL) { + OQS_MEM_aligned_free(state->ctx); + state->ctx = NULL; + } } static void SHA3_shake256_inc_ctx_clone(OQS_SHA3_shake256_inc_ctx *dest, const OQS_SHA3_shake256_inc_ctx *src) { - memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES); + if (dest == NULL || src == NULL || src->ctx == NULL) { + return; + } + if (dest->ctx == NULL) { + dest->ctx = OQS_MEM_checked_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES); + } + if (dest->ctx != NULL) { + memcpy(dest->ctx, src->ctx, KECCAK_CTX_BYTES); + } } static void SHA3_shake256_inc_ctx_reset(OQS_SHA3_shake256_inc_ctx *state) { + if (state == NULL || state->ctx == NULL) { + return; + } keccak_inc_reset((uint64_t *)state->ctx); } diff --git a/src/common/sha3/xkcp_sha3x4.c b/src/common/sha3/xkcp_sha3x4.c index bd441a01ff..a91d9c19b3 100644 --- a/src/common/sha3/xkcp_sha3x4.c +++ b/src/common/sha3/xkcp_sha3x4.c @@ -72,12 +72,18 @@ static void keccak_x4_inc_reset(uint64_t *s) { Keccak_X4_Dispatch(); } #endif - (*Keccak_X4_Initialize_ptr)(s); - s[100] = 0; + if (Keccak_X4_Initialize_ptr != NULL) { + (*Keccak_X4_Initialize_ptr)(s); + s[100] = 0; + } } static void keccak_x4_inc_absorb(uint64_t *s, uint32_t r, const uint8_t *in0, const uint8_t *in1, const uint8_t *in2, const uint8_t *in3, size_t inlen) { + if (s == NULL || Keccak_X4_AddBytes_ptr == NULL || Keccak_X4_Permute_ptr == NULL) { + return; + } + uint64_t c = r - s[100]; if (s[100] && inlen >= c) { @@ -115,6 +121,10 @@ static void keccak_x4_inc_absorb(uint64_t *s, uint32_t r, } static void keccak_x4_inc_finalize(uint64_t *s, uint32_t r, uint8_t p) { + if (s == NULL || Keccak_X4_AddByte_ptr == NULL) { + return; + } + (*Keccak_X4_AddByte_ptr)(s, 0, p, (unsigned int)s[100]); (*Keccak_X4_AddByte_ptr)(s, 1, p, (unsigned int)s[100]); (*Keccak_X4_AddByte_ptr)(s, 2, p, (unsigned int)s[100]); @@ -130,6 +140,9 @@ static void keccak_x4_inc_finalize(uint64_t *s, uint32_t r, uint8_t p) { static void keccak_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, size_t outlen, uint64_t *s, uint32_t r) { + if (s == NULL || Keccak_X4_ExtractBytes_ptr == NULL || Keccak_X4_Permute_ptr == NULL) { + return; + } while (outlen > s[100]) { (*Keccak_X4_ExtractBytes_ptr)(s, 0, out0, (unsigned int)(r - s[100]), (unsigned int)s[100]); @@ -158,41 +171,57 @@ static void keccak_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *out2, u static void SHA3_shake128_x4(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, size_t outlen, const uint8_t *in0, const uint8_t *in1, const uint8_t *in2, const uint8_t *in3, size_t inlen) { OQS_SHA3_shake128_x4_inc_ctx s; OQS_SHA3_shake128_x4_inc_init(&s); - OQS_SHA3_shake128_x4_inc_absorb(&s, in0, in1, in2, in3, inlen); - OQS_SHA3_shake128_x4_inc_finalize(&s); - OQS_SHA3_shake128_x4_inc_squeeze(out0, out1, out2, out3, outlen, &s); - OQS_SHA3_shake128_x4_inc_ctx_release(&s); + if (s.ctx != NULL) { + OQS_SHA3_shake128_x4_inc_absorb(&s, in0, in1, in2, in3, inlen); + OQS_SHA3_shake128_x4_inc_finalize(&s); + OQS_SHA3_shake128_x4_inc_squeeze(out0, out1, out2, out3, outlen, &s); + OQS_SHA3_shake128_x4_inc_ctx_release(&s); + } } /* SHAKE128 incremental */ static void SHA3_shake128_x4_inc_init(OQS_SHA3_shake128_x4_inc_ctx *state) { state->ctx = OQS_MEM_checked_aligned_alloc(KECCAK_X4_CTX_ALIGNMENT, KECCAK_X4_CTX_BYTES); - keccak_x4_inc_reset((uint64_t *)state->ctx); + if (state->ctx != NULL) { + keccak_x4_inc_reset((uint64_t *)state->ctx); + } } static void SHA3_shake128_x4_inc_absorb(OQS_SHA3_shake128_x4_inc_ctx *state, const uint8_t *in0, const uint8_t *in1, const uint8_t *in2, const uint8_t *in3, size_t inlen) { - keccak_x4_inc_absorb((uint64_t *)state->ctx, OQS_SHA3_SHAKE128_RATE, in0, in1, in2, in3, inlen); + if (state->ctx != NULL) { + keccak_x4_inc_absorb((uint64_t *)state->ctx, OQS_SHA3_SHAKE128_RATE, in0, in1, in2, in3, inlen); + } } static void SHA3_shake128_x4_inc_finalize(OQS_SHA3_shake128_x4_inc_ctx *state) { - keccak_x4_inc_finalize((uint64_t *)state->ctx, OQS_SHA3_SHAKE128_RATE, 0x1F); + if (state->ctx != NULL) { + keccak_x4_inc_finalize((uint64_t *)state->ctx, OQS_SHA3_SHAKE128_RATE, 0x1F); + } } static void SHA3_shake128_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, size_t outlen, OQS_SHA3_shake128_x4_inc_ctx *state) { - keccak_x4_inc_squeeze(out0, out1, out2, out3, outlen, (uint64_t *)state->ctx, OQS_SHA3_SHAKE128_RATE); + if (state->ctx != NULL) { + keccak_x4_inc_squeeze(out0, out1, out2, out3, outlen, (uint64_t *)state->ctx, OQS_SHA3_SHAKE128_RATE); + } } - static void SHA3_shake128_x4_inc_ctx_clone(OQS_SHA3_shake128_x4_inc_ctx *dest, const OQS_SHA3_shake128_x4_inc_ctx *src) { - memcpy(dest->ctx, src->ctx, KECCAK_X4_CTX_BYTES); + if (dest && src && dest->ctx && src->ctx) { + memcpy(dest->ctx, src->ctx, KECCAK_X4_CTX_BYTES); + } } static void SHA3_shake128_x4_inc_ctx_release(OQS_SHA3_shake128_x4_inc_ctx *state) { - OQS_MEM_aligned_free(state->ctx); + if (state) { + OQS_MEM_aligned_free(state->ctx); + state->ctx = NULL; + } } static void SHA3_shake128_x4_inc_ctx_reset(OQS_SHA3_shake128_x4_inc_ctx *state) { - keccak_x4_inc_reset((uint64_t *)state->ctx); + if (state && state->ctx) { + keccak_x4_inc_reset((uint64_t *)state->ctx); + } } /********** SHAKE256 ***********/ @@ -200,41 +229,60 @@ static void SHA3_shake128_x4_inc_ctx_reset(OQS_SHA3_shake128_x4_inc_ctx *state) static void SHA3_shake256_x4(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, size_t outlen, const uint8_t *in0, const uint8_t *in1, const uint8_t *in2, const uint8_t *in3, size_t inlen) { OQS_SHA3_shake256_x4_inc_ctx s; OQS_SHA3_shake256_x4_inc_init(&s); - OQS_SHA3_shake256_x4_inc_absorb(&s, in0, in1, in2, in3, inlen); - OQS_SHA3_shake256_x4_inc_finalize(&s); - OQS_SHA3_shake256_x4_inc_squeeze(out0, out1, out2, out3, outlen, &s); - OQS_SHA3_shake256_x4_inc_ctx_release(&s); + if (s.ctx != NULL) { + OQS_SHA3_shake256_x4_inc_absorb(&s, in0, in1, in2, in3, inlen); + OQS_SHA3_shake256_x4_inc_finalize(&s); + OQS_SHA3_shake256_x4_inc_squeeze(out0, out1, out2, out3, outlen, &s); + OQS_SHA3_shake256_x4_inc_ctx_release(&s); + } } /* SHAKE256 incremental */ static void SHA3_shake256_x4_inc_init(OQS_SHA3_shake256_x4_inc_ctx *state) { - state->ctx = OQS_MEM_checked_aligned_alloc(KECCAK_X4_CTX_ALIGNMENT, KECCAK_X4_CTX_BYTES); - keccak_x4_inc_reset((uint64_t *)state->ctx); + if (state) { + state->ctx = OQS_MEM_checked_aligned_alloc(KECCAK_X4_CTX_ALIGNMENT, KECCAK_X4_CTX_BYTES); + if (state->ctx) { + keccak_x4_inc_reset((uint64_t *)state->ctx); + } + } } static void SHA3_shake256_x4_inc_absorb(OQS_SHA3_shake256_x4_inc_ctx *state, const uint8_t *in0, const uint8_t *in1, const uint8_t *in2, const uint8_t *in3, size_t inlen) { - keccak_x4_inc_absorb((uint64_t *)state->ctx, OQS_SHA3_SHAKE256_RATE, in0, in1, in2, in3, inlen); + if (state && state->ctx) { + keccak_x4_inc_absorb((uint64_t *)state->ctx, OQS_SHA3_SHAKE256_RATE, in0, in1, in2, in3, inlen); + } } static void SHA3_shake256_x4_inc_finalize(OQS_SHA3_shake256_x4_inc_ctx *state) { - keccak_x4_inc_finalize((uint64_t *)state->ctx, OQS_SHA3_SHAKE256_RATE, 0x1F); + if (state && state->ctx) { + keccak_x4_inc_finalize((uint64_t *)state->ctx, OQS_SHA3_SHAKE256_RATE, 0x1F); + } } static void SHA3_shake256_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_t *out3, size_t outlen, OQS_SHA3_shake256_x4_inc_ctx *state) { - keccak_x4_inc_squeeze(out0, out1, out2, out3, outlen, (uint64_t *)state->ctx, OQS_SHA3_SHAKE256_RATE); + if (state && state->ctx) { + keccak_x4_inc_squeeze(out0, out1, out2, out3, outlen, (uint64_t *)state->ctx, OQS_SHA3_SHAKE256_RATE); + } } static void SHA3_shake256_x4_inc_ctx_clone(OQS_SHA3_shake256_x4_inc_ctx *dest, const OQS_SHA3_shake256_x4_inc_ctx *src) { - memcpy(dest->ctx, src->ctx, KECCAK_X4_CTX_BYTES); + if (dest && src && dest->ctx && src->ctx) { + memcpy(dest->ctx, src->ctx, KECCAK_X4_CTX_BYTES); + } } static void SHA3_shake256_x4_inc_ctx_release(OQS_SHA3_shake256_x4_inc_ctx *state) { - OQS_MEM_aligned_free(state->ctx); + if (state) { + OQS_MEM_aligned_free(state->ctx); + state->ctx = NULL; + } } static void SHA3_shake256_x4_inc_ctx_reset(OQS_SHA3_shake256_x4_inc_ctx *state) { - keccak_x4_inc_reset((uint64_t *)state->ctx); + if (state && state->ctx) { + keccak_x4_inc_reset((uint64_t *)state->ctx); + } } extern struct OQS_SHA3_x4_callbacks sha3_x4_default_callbacks;