Skip to content

Commit

Permalink
Add column for serializing a packed bit field. (#498)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #498

# Background:

Currently in order to successfully use UDP, you must write some carefully crafted code that will take all the rows of metadata for one side and package it into a collection of bytes. Afterwards the caller will get a `SecString` object back which is a bit representation of all the bytes they passed in, minus the filtered out rows. The user must then extract the corresponding bits for each column into separate MPC Types.  This is a cumbersome process which is error prone, as you must make sure to carefully match up the two steps and any changes can cause a bug.

# This Diff

Adds support for a packed bit field column. The current plan is that the user will not directly interface with this column type, but it is internal to the row structure. It will support up to 8 different boolean values to be packed inside of it. The reason for this is that a bool value takes up a whole byte in memory. This means if we naively store these, each bool will actually take 8 bits in the final row structure. This allows us to fit 8 bool columns into one byte!

Reviewed By: haochenuw

Differential Revision: D43289316

fbshipit-source-id: 0100109c1a691a2eae7d4691943c11bdb21ef427
  • Loading branch information
Tal Davidi authored and facebook-github-bot committed Feb 22, 2023
1 parent d2c75d1 commit 144c868
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include "fbpcf/frontend/MPCTypes.h"
#include "fbpcf/mpc_std_lib/unified_data_process/serialization/IColumnDefinition.h"

#include <string>
namespace fbpcf::mpc_std_lib::unified_data_process::serialization {

template <int schedulerId>
class PackedBitFieldColumn : public IColumnDefinition<schedulerId> {
using NativeType = std::vector<bool>;
using MPCTypes = frontend::MPCTypes<schedulerId, true>;

public:
PackedBitFieldColumn(
std::string columnName,
std::vector<std::string> subColumnNames)
: columnName_{columnName}, subColumnNames_{subColumnNames} {
if (subColumnNames.size() > 8) {
throw std::runtime_error(
"Can only pack 8 bits into a byte. Please create another PackedBitField"
" if you would like to store additional boolean values");
}
}

std::string getColumnName() const override {
return columnName_;
}

const std::vector<std::string>& getSubColumnNames() const {
return subColumnNames_;
}

size_t getColumnSizeBytes() const override {
return 1;
}

// input to this function is a pointer to a bool vector since memory layout
// is not guaranteed by compiler (i.e. can not get a bool* from a
// vector<bool>.data())
void serializeColumnAsPlaintextBytes(
const void* inputData,
unsigned char* buf) const override {
const NativeType value = *((NativeType*)inputData);
if (value.size() != subColumnNames_.size()) {
throw std::runtime_error(
"Size mismatch between expected number of packed bits and actual data");
}

unsigned char toWrite = 0;
for (size_t i = 0; i < value.size(); ++i) {
toWrite |= (value[i] << i);
}
buf[0] = toWrite;
}

typename IColumnDefinition<schedulerId>::DeserializeType
deserializeSharesToMPCType(
const std::vector<std::vector<unsigned char>>& serializedSecretShares,
size_t offset) const override {
std::vector<std::vector<bool>> reconstructedShares(
subColumnNames_.size(),
std::vector<bool>(serializedSecretShares.size()));

for (int i = 0; i < serializedSecretShares.size(); i++) {
unsigned char packedValues = serializedSecretShares[i][offset];
for (int j = 0; j < subColumnNames_.size(); j++) {
reconstructedShares[j][i] = (packedValues >> j) & 1;
}
}

std::vector<typename MPCTypes::SecBool> rst(reconstructedShares.size());
for (int i = 0; i < reconstructedShares.size(); i++) {
rst[i] = typename MPCTypes::SecBool(
typename MPCTypes::SecBool::ExtractedBit(reconstructedShares[i]));
}

return rst;
}

private:
std::string columnName_;
std::vector<std::string> subColumnNames_;
};

} // namespace fbpcf::mpc_std_lib::unified_data_process::serialization
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
#include "fbpcf/scheduler/SchedulerHelper.h"
#include "fbpcf/test/TestHelper.h"

#include "fbpcf/mpc_std_lib/unified_data_process/serialization/FixedSizeArrayColumn.h"
#include "fbpcf/mpc_std_lib/unified_data_process/serialization/IntegerColumn.h"
#include "fbpcf/mpc_std_lib/unified_data_process/serialization/PackedBitFieldColumn.h"

namespace fbpcf::mpc_std_lib::unified_data_process::serialization {

Expand Down Expand Up @@ -75,6 +77,32 @@ static std::vector<std::vector<int32_t>> deserializeAndRevealInt32Vector(
return rst;
}

template <int schedulerId>
static std::vector<std::vector<bool>> deserializeAndRevealPackedBits(
fbpcf::scheduler::ISchedulerFactory<true>& schedulerFactory,
const std::vector<std::vector<unsigned char>>& serializedSecretShares,
IColumnDefinition<schedulerId>& serializer) {
auto scheduler = schedulerFactory.create();

fbpcf::scheduler::SchedulerKeeper<schedulerId>::setScheduler(
std::move(scheduler));

typename IColumnDefinition<schedulerId>::DeserializeType mpcValue =
serializer.deserializeSharesToMPCType(serializedSecretShares, 0);
std::vector<typename frontend::MPCTypes<schedulerId>::SecBool> visitedVal =
std::get<std::vector<typename frontend::MPCTypes<schedulerId>::SecBool>>(
mpcValue);

std::vector<std::vector<bool>> rst(
visitedVal.size(), std::vector<bool>(visitedVal[0].getBatchSize()));

for (int i = 0; i < visitedVal.size(); i++) {
rst[i] = visitedVal[i].openToParty(0).getValue();
}

return rst;
}

TEST(SerializationTest, IntegerColumnTest) {
auto factories = fbpcf::engine::communication::getInMemoryAgentFactory(2);

Expand Down Expand Up @@ -207,4 +235,68 @@ TEST(SerializationTest, ArrayColumnTest) {
}
}

TEST(SerializationTest, PackedBitFieldColumnTest) {
auto factories = fbpcf::engine::communication::getInMemoryAgentFactory(2);

auto schedulerFactory0 =
fbpcf::scheduler::NetworkPlaintextSchedulerFactory<true>(
0, *factories[0]);

auto schedulerFactory1 =
fbpcf::scheduler::NetworkPlaintextSchedulerFactory<true>(
1, *factories[1]);

const size_t batchSize = 100;
const size_t numBits = 7;

std::random_device rd;
std::mt19937_64 e(rd());
std::uniform_int_distribution<> dist(0, 1);

PackedBitFieldColumn<0> serializer0(
"testColumnName", std::vector<std::string>(7, "testColumnName"));
PackedBitFieldColumn<1> serializer1(
"testColumnName", std::vector<std::string>(7, "testColumnName"));
EXPECT_EQ(serializer0.getColumnSizeBytes(), 1);

std::vector<std::vector<uint8_t>> bufs(
batchSize, std::vector<uint8_t>(serializer0.getColumnSizeBytes()));

std::vector<std::vector<bool>> vals(numBits, std::vector<bool>(batchSize));

for (int i = 0; i < batchSize; i++) {
std::vector<bool> val(numBits);
for (int j = 0; j < numBits; j++) {
val[j] = dist(e);
vals[j][i] = val[j];
}

serializer0.serializeColumnAsPlaintextBytes(&val, bufs[i].data());

for (int j = 0; j < numBits; j++) {
EXPECT_EQ((bufs[i][0] >> j) & 1, val[j]);
}
}

auto future0 = std::async([&schedulerFactory0, &bufs, &serializer0]() {
return deserializeAndRevealPackedBits<0>(
schedulerFactory0, bufs, serializer0);
});

auto future1 = std::async([&schedulerFactory1, &serializer1]() {
return deserializeAndRevealPackedBits<1>(
schedulerFactory1,
std::vector<std::vector<uint8_t>>(
batchSize, std::vector<uint8_t>(serializer1.getColumnSizeBytes())),
serializer1);
});

auto rst = future0.get();
future1.get();

EXPECT_EQ(rst.size(), numBits);
for (int j = 0; j < numBits; j++) {
testVectorEq(vals[j], rst[j]);
}
}
} // namespace fbpcf::mpc_std_lib::unified_data_process::serialization

0 comments on commit 144c868

Please sign in to comment.