Skip to content

Commit

Permalink
Replace MessageDigest.isEqual with our own implementation
Browse files Browse the repository at this point in the history
+ The documentation of MessageDigest.isEqual does not guarantee constant
  time despite its implementation for most JDKs is constant time.
  • Loading branch information
amirhosv authored and geedo0 committed Oct 2, 2023
1 parent 2e719c7 commit 35897a2
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 18 deletions.
5 changes: 2 additions & 3 deletions src/com/amazon/corretto/crypto/provider/AesGcmSpi.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.spec.AlgorithmParameterSpec;
Expand Down Expand Up @@ -340,7 +339,7 @@ protected void engineInit(
throw new InvalidKeyException("Key doesn't support encoding");
}

if (!MessageDigest.isEqual(encodedKey, this.key)) {
if (!ConstantTime.equals(this.key, encodedKey)) {
if (encodedKey.length != 128 / 8
&& encodedKey.length != 192 / 8
&& encodedKey.length != 256 / 8) {
Expand Down Expand Up @@ -372,7 +371,7 @@ protected void engineInit(
&& this.key != null
&& (jceOpMode == Cipher.ENCRYPT_MODE || jceOpMode == Cipher.WRAP_MODE)) {
if (Arrays.equals(this.iv, iv)
&& (encodedKey == null || MessageDigest.isEqual(this.key, encodedKey))) {
&& (encodedKey == null || ConstantTime.equals(this.key, encodedKey))) {
throw new InvalidAlgorithmParameterException(
"Cannot reuse same iv and key for GCM encryption");
}
Expand Down
6 changes: 1 addition & 5 deletions src/com/amazon/corretto/crypto/provider/AesXtsSpi.java
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,7 @@ private boolean checkKeyTweakEquality(final byte[] key, final byte[] tweak) {
if (packedTweakKey[i] != tweak[i]) return false;
}

for (int i = 0; i != KEY_SIZE_IN_BYTES; i++) {
if (packedTweakKey[i + TWEAK_SIZE_IN_BYTES] != key[i]) return false;
}

return true;
return ConstantTime.equals(packedTweakKey, TWEAK_SIZE_IN_BYTES, KEY_SIZE_IN_BYTES, key);
}

@Override
Expand Down
57 changes: 50 additions & 7 deletions src/com/amazon/corretto/crypto/provider/ConstantTime.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,45 +4,47 @@

/** Contains several constant time utilities */
final class ConstantTime {
private static int LONG_UNSIGNED_SHIFT = Long.SIZE - 1;

private ConstantTime() {
// Prevent instantiation
}

/** Equivalent to {@code val != 0 ? 1 : 0} */
static final int isNonZero(int val) {
static int isNonZero(int val) {
return ((val | -val) >>> 31) & 0x01; // Unsigned bitshift
}

/** Equivalent to {@code val == 0 ? 1 : 0} */
static final int isZero(int val) {
static int isZero(int val) {
return 1 - isNonZero(val);
}

/** Equivalent to {@code val < 0 ? 1 : 0} */
static final int isNegative(int val) {
static int isNegative(int val) {
return (val >>> 31) & 0x01;
}

/** Equivalent to {@code x == y ? 1 : 0} */
static final int equal(int x, int y) {
static int equal(int x, int y) {
final int difference = x - y;
// Difference is 0 iff x == y
return isZero(difference);
}

/** Equivalent to {@code x > y ? 1 : 0} */
static final int gt(int x, int y) {
static int gt(int x, int y) {
// Convert to long to avoid underflow
final long xl = x;
final long yl = y;
final long difference = yl - xl;
// If xl > yl, then difference is negative.
// Thus, we can just return the sign-bit
return (int) ((difference >>> 63) & 0x01); // Unsigned bitshift
return (int) (difference >>> LONG_UNSIGNED_SHIFT); // Unsigned bitshift
}

/** Equivalent to {@code selector != 0 ? a : b} */
static final int select(int selector, int a, int b) {
static int select(int selector, int a, int b) {
final int mask = isZero(selector) - 1;
// Mask == -1 (all bits 1) iff selector != 0
// Mask == 0 (all bits 0) iff selector == 0
Expand All @@ -51,4 +53,45 @@ static final int select(int selector, int a, int b) {

return b ^ (combined & mask);
}

/**
* @return true iff all the bytes in the specified ranges are equal
*/
static boolean equals(
final byte[] a,
final int aStart,
final int aLen,
final byte[] b,
final int bStart,
final int bLen) {

Utils.checkArrayLimits(a, aStart, aLen);
Utils.checkArrayLimits(b, bStart, bLen);

if (aLen != bLen) {
return false;
}

int result = 0;

for (int i = 0; i < aLen; i++) {
result |= a[aStart + i] ^ b[bStart + i];
}

return result == 0;
}

static boolean equals(final byte[] a, final int aStart, final int aLen, final byte[] b) {
return equals(a, aStart, aLen, b, 0, b.length);
}

static boolean equals(final byte[] a, final byte[] b) {
if (a == b) {
return true;
}
if (a == null || b == null) {
return false;
}
return equals(a, 0, a.length, b);
}
}
3 changes: 1 addition & 2 deletions src/com/amazon/corretto/crypto/provider/EvpKey.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import java.security.GeneralSecurityException;
import java.security.Key;
import java.security.KeyFactory;
import java.security.MessageDigest;
import java.security.PublicKey;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.InvalidKeySpecException;
Expand Down Expand Up @@ -160,7 +159,7 @@ public boolean equals(final Object obj) {
}

// Constant time equality check
return MessageDigest.isEqual(internalGetEncoded(), otherEncoded);
return ConstantTime.equals(internalGetEncoded(), otherEncoded);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
// SPDX-License-Identifier: Apache-2.0
package com.amazon.corretto.crypto.provider.test;

import static com.amazon.corretto.crypto.provider.test.TestUtil.sneakyInvoke_boolean;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.params.provider.Arguments.arguments;

import java.util.ArrayList;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;
Expand All @@ -18,7 +22,7 @@
// correct answers.
@ExtendWith(TestResultLogger.class)
@Execution(ExecutionMode.CONCURRENT)
public class ConstantTimeTests {
public class ConstantTimeTest {
// A few common values which when combined can trigger edge cases
private static final int[] TEST_VALUES = {
Integer.MIN_VALUE,
Expand Down Expand Up @@ -122,4 +126,31 @@ private static int sneaky(String name, int a, int b, int c) {
throw new AssertionError(t);
}
}

@Test
public void testConstantTimeEquality() throws Throwable {
assertTrue(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", new byte[0], new byte[0]));
final byte[] trusted = {0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7};
final byte[] other = {0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7};
assertTrue(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", trusted, trusted));
assertTrue(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", trusted, other));
assertTrue(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", trusted, 0, 8, trusted, 8, 8));
assertTrue(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", trusted, 8, 8, trusted, 0, 8));

assertFalse(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", new byte[5], new byte[6]));
assertFalse(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", trusted, new byte[0]));
assertFalse(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", trusted, 0, 8, trusted, 8, 7));
assertFalse(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", trusted, 8, 6, trusted, 0, 10));
assertFalse(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", trusted, 0, 8, other, 8, 7));
trusted[0]++;
assertFalse(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", trusted, 0, 8, trusted, 8, 8));
assertFalse(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", trusted, 8, 8, trusted, 0, 8));
assertTrue(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", trusted, 1, 7, trusted, 9, 7));
assertTrue(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", trusted, 9, 7, trusted, 1, 7));
assertFalse(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", trusted, other));

assertTrue(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", null, null));
assertFalse(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", null, other));
assertFalse(sneakyInvoke_boolean(CONSTANT_TIME_CLASS, "equals", trusted, null));
}
}
5 changes: 5 additions & 0 deletions tst/com/amazon/corretto/crypto/provider/test/TestUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ public static int sneakyInvoke_int(Object o, String methodName, Object... args)
return (Integer) sneakyInvoke(o, methodName, args);
}

public static boolean sneakyInvoke_boolean(Object o, String methodName, Object... args)
throws Throwable {
return (Boolean) sneakyInvoke(o, methodName, args);
}

public static <T> T sneakyInvoke(Object o, String methodName, Object... args) throws Throwable {
Class<?> klass;
Object receiver;
Expand Down

0 comments on commit 35897a2

Please sign in to comment.