Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement MPC-AES decryption circuit #447

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fbpcf/mpc_std_lib/aes_circuit/AesCircuit.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class AesCircuit : public IAesCircuit<BitType> {
void inverseMixColumnsInPlace(WordType& src) const;

void shiftRowInPlace(std::array<WordType, 4>& src) const;

void inverseShiftRowInPlace(std::array<WordType, 4>& src) const;
#ifdef AES_CIRCUIT_TEST_FRIENDS
AES_CIRCUIT_TEST_FRIENDS;
#endif
Expand Down
79 changes: 76 additions & 3 deletions fbpcf/mpc_std_lib/aes_circuit/AesCircuit_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,65 @@ std::vector<BitType> AesCircuit<BitType>::encrypt_impl(
return convertFromWords(plaintextBlocks);
}

// implementation based on https://engineering.purdue.edu/kak/compsec/NewLectures/Lecture8.pdf
template <typename BitType>
std::vector<BitType> AesCircuit<BitType>::decrypt_impl(
const std::vector<BitType>& /* ciphertext */,
const std::vector<BitType>& /* expandedDecKey */) const {
throw std::runtime_error("Not implemented!");
const std::vector<BitType>& ciphertext,
const std::vector<BitType>& expandedDecKey) const {
// prepare input
auto ciphertextBlocks = convertToWords(ciphertext);
auto roundKeys = convertToWords(expandedDecKey);
size_t blockNo = ciphertextBlocks.size();

int round = 10;
// pre-round
for (int block = 0; block < blockNo; ++block) {
for (int word = 0; word < 4; ++word) {
for (int byte = 0; byte < 4; ++byte) {
for (int bit = 0; bit < 8; ++bit) {
ciphertextBlocks[block][word][byte][bit] =
ciphertextBlocks[block][word][byte][bit] ^
roundKeys[round][word][byte][bit];
}
}
}
}
// rounds 1 - 10
for (int round = 9; round >= 0; --round) {
// InverseShiftRows
for (int block = 0; block < blockNo; ++block) {
inverseShiftRowInPlace(ciphertextBlocks[block]);
}
// InverseSbox
for (int block = 0; block < blockNo; ++block) {
for (int word = 0; word < 4; ++word) {
for (int byte = 0; byte < 4; ++byte) {
inverseSBoxInPlace(ciphertextBlocks[block][word][byte]);
}
}
}
// AddRoundKey
for (int block = 0; block < blockNo; ++block) {
for (int word = 0; word < 4; ++word) {
for (int byte = 0; byte < 4; ++byte) {
for (int bit = 0; bit < 8; ++bit) {
ciphertextBlocks[block][word][byte][bit] =
ciphertextBlocks[block][word][byte][bit] ^
roundKeys[round][word][byte][bit];
}
}
}
}
// InverseMixColumns except for 10-th Round
if (round != 0) {
for (int block = 0; block < blockNo; ++block) {
for (int word = 0; word < 4; ++word) {
inverseMixColumnsInPlace(ciphertextBlocks[block][word]);
}
}
}
}
return convertFromWords(ciphertextBlocks);
}

template <typename BitType>
Expand Down Expand Up @@ -494,4 +548,23 @@ void AesCircuit<BitType>::shiftRowInPlace(std::array<WordType, 4>& src) const {
std::swap(src[1][row], src[0][row]);
}

template <typename BitType>
void AesCircuit<BitType>::inverseShiftRowInPlace(
std::array<WordType, 4>& src) const {
// 1st row is not shifted, 2nd row shifted right by 1
int row = 1;
std::swap(src[2][row], src[3][row]);
std::swap(src[1][row], src[2][row]);
std::swap(src[0][row], src[1][row]);
// 3rd row shifted right by 2
row++;
std::swap(src[0][row], src[2][row]);
std::swap(src[1][row], src[3][row]);
// 4th row shifted right by 3
row++;
std::swap(src[0][row], src[1][row]);
std::swap(src[1][row], src[2][row]);
std::swap(src[2][row], src[3][row]);
}

} // namespace fbpcf::mpc_std_lib::aes_circuit
87 changes: 87 additions & 0 deletions fbpcf/mpc_std_lib/aes_circuit/test/AesCircuitTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,28 @@ class AesCircuitTests : public AesCircuit<BitType> {
}
}

void testInverseShiftRowInPlace(std::vector<bool> plaintext) {
std::array<std::array<std::array<bool, 8>, 4>, 4> block;
for (int k = 0; k < 4; ++k) {
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 8; j++) {
block[k][i][j] = plaintext[32 * k + 8 * i + j];
}
}
}

AesCircuit<bool>::inverseShiftRowInPlace(block);
for (int k = 0; k < 4; ++k) {
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 8; j++) {
EXPECT_EQ(
block[k][i][j],
plaintext[32 * ((((k - i) % 4) + 4) % 4) + 8 * i + j]);
}
}
}
}

void testWordConversion() {
using ByteType = std::array<bool, 8>;
using WordType = std::array<ByteType, 4>;
Expand Down Expand Up @@ -159,6 +181,12 @@ TEST(AesCircuitTest, testShiftRowInPlace) {
test.testShiftRowInPlace(plaintext);
}

TEST(AesCircuitTest, testInverseShiftRowInPlace) {
auto plaintext = generateRandomPlaintext();
AesCircuitTests<bool> test;
test.testInverseShiftRowInPlace(plaintext);
}

TEST(AesCircuitTest, testWordConversion) {
AesCircuitTests<bool> test;
test.testWordConversion();
Expand Down Expand Up @@ -352,6 +380,65 @@ TEST(AesCircuitTest, testAesCircuitEncrypt) {
testAesCircuitEncrypt(std::make_unique<AesCircuitFactory<bool>>());
}

void testAesCircuitDecrypt(
std::shared_ptr<AesCircuitFactory<bool>> AesCircuitFactory) {
auto AesCircuit = AesCircuitFactory->create();

std::random_device rd;
std::mt19937_64 e(rd());
std::uniform_int_distribution<uint8_t> dist(0, 0xFF);
size_t blockNo = dist(e);

// generate random key
__m128i key = _mm_set_epi32(dist(e), dist(e), dist(e), dist(e));
// generate random plaintext
std::vector<uint8_t> plaintext;
plaintext.reserve(blockNo * 16);
for (int i = 0; i < blockNo * 16; ++i) {
plaintext.push_back(dist(e));
}
std::vector<__m128i> plaintextAES;
loadValueToLocalAes(plaintext, plaintextAES);

// expand key
engine::util::Aes truthAes(key);
auto expandedKey = truthAes.expandEncryptionKey(key);
// extract key and plaintext
std::vector<uint8_t> extractedKeys;
extractedKeys.reserve(176);
for (auto keyb : expandedKey) {
loadValueFromLocalAes(keyb, extractedKeys);
}

// convert key and plaintext into bool vector
std::vector<bool> keyBits;
keyBits.reserve(1408);
int8VecToBinaryVec(extractedKeys, keyBits);
std::vector<bool> plaintextBits;
plaintextBits.reserve(blockNo * 128);
int8VecToBinaryVec(plaintext, plaintextBits);

// encrypt in real aes
truthAes.encryptInPlace(plaintextAES);

// extract ciphertext in real aes
std::vector<uint8_t> ciphertextTruth;
ciphertextTruth.reserve(blockNo * 16);
for (auto b : plaintextAES) {
loadValueFromLocalAes(b, ciphertextTruth);
}
std::vector<bool> cipherextBitsTruth;
cipherextBitsTruth.reserve(blockNo * 128);
int8VecToBinaryVec(ciphertextTruth, cipherextBitsTruth);
// decrypt this ciphertext using our decrypt circuit
auto decryptionBits = AesCircuit->decrypt(cipherextBitsTruth, keyBits);
testVectorEq(decryptionBits, plaintextBits);
}

TEST(AesCircuitTest, testAesCircuitDecrypt) {
testAesCircuitDecrypt(std::make_unique<AesCircuitFactory<bool>>());
}

void testAesCircuitCtr(
std::shared_ptr<AesCircuitCtrFactory<bool>> AesCircuitCtrFactory) {
auto AesCircuitCtr = AesCircuitCtrFactory->create();
Expand Down