Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Corrected operator== for the family of cryptoparameters classes #914

Merged
merged 3 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/core/include/lattice/hal/default/ildcrtparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ class ILDCRTParams final : public ElemParams<IntType> {
return 1;
}

private:
protected:
std::ostream& doprint(std::ostream& out) const override {
out << "ILDCRTParams ";
ElemParams<IntType>::doprint(out);
Expand All @@ -355,6 +355,7 @@ class ILDCRTParams final : public ElemParams<IntType> {
return out << std::endl;
}

private:
// array of smaller ILParams
std::vector<std::shared_ptr<ILNativeParams>> m_params;
};
Expand Down
2 changes: 1 addition & 1 deletion src/core/include/lattice/hal/default/ilparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class ILParamsImpl final : public ElemParams<IntType> {
return 1;
}

private:
protected:
std::ostream& doprint(std::ostream& out) const override {
out << "ILParams ";
ElemParams<IntType>::doprint(out);
Expand Down
72 changes: 30 additions & 42 deletions src/pke/include/encoding/ckkspackedencoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class CKKSPackedEncoding : public PlaintextImpl {
std::is_same<T, NativePoly::Params>::value ||
std::is_same<T, DCRTPoly::Params>::value,
bool>::type = true>
CKKSPackedEncoding(std::shared_ptr<T> vp, EncodingParams ep) : PlaintextImpl(vp, ep, CKKSRNS_SCHEME) {
CKKSPackedEncoding(std::shared_ptr<T> vp, EncodingParams ep) : PlaintextImpl(vp, ep, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME) {
this->slots = GetDefaultSlotSize();
if (this->slots > (GetElementRingDimension() / 2)) {
OPENFHE_THROW("The number of slots cannot be larger than half of ring dimension");
Expand All @@ -83,7 +83,7 @@ class CKKSPackedEncoding : public PlaintextImpl {
bool>::type = true>
CKKSPackedEncoding(std::shared_ptr<T> vp, EncodingParams ep, const std::vector<std::complex<double>>& coeffs,
size_t noiseScaleDeg, uint32_t level, double scFact, size_t slots)
: PlaintextImpl(vp, ep, CKKSRNS_SCHEME), value(coeffs) {
: PlaintextImpl(vp, ep, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME), value(coeffs) {
// validate the number of slots
if ((slots & (slots - 1)) != 0) {
OPENFHE_THROW("The number of slots should be a power of two");
Expand All @@ -109,7 +109,7 @@ class CKKSPackedEncoding : public PlaintextImpl {
* @param rhs - The input object to copy.
*/
explicit CKKSPackedEncoding(const std::vector<std::complex<double>>& rhs, size_t slots)
: PlaintextImpl(std::shared_ptr<Poly::Params>(0), nullptr, CKKSRNS_SCHEME), value(rhs) {
: PlaintextImpl(std::shared_ptr<Poly::Params>(0), nullptr, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME), value(rhs) {
// validate the number of slots
if ((slots & (slots - 1)) != 0) {
OPENFHE_THROW("The number of slots should be a power of two");
Expand All @@ -128,7 +128,7 @@ class CKKSPackedEncoding : public PlaintextImpl {
/**
* @brief Default empty constructor with empty uninitialized data elements.
*/
CKKSPackedEncoding() : PlaintextImpl(std::shared_ptr<Poly::Params>(0), nullptr, CKKSRNS_SCHEME) {
CKKSPackedEncoding() : PlaintextImpl(std::shared_ptr<Poly::Params>(0), nullptr, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME) {
this->slots = GetDefaultSlotSize();
if (this->slots > (GetElementRingDimension() / 2)) {
OPENFHE_THROW("The number of slots cannot be larger than half of ring dimension");
Expand All @@ -147,7 +147,7 @@ class CKKSPackedEncoding : public PlaintextImpl {
OPENFHE_THROW("CKKSPackedEncoding::Decode() is not implemented. Use CKKSPackedEncoding::Decode(...) instead.");
}

bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode);
bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode) override;

const std::vector<std::complex<double>>& GetCKKSPackedValue() const override {
return value;
Expand Down Expand Up @@ -175,14 +175,6 @@ class CKKSPackedEncoding : public PlaintextImpl {
const std::vector<DCRTPoly::Integer>& b,
const std::vector<DCRTPoly::Integer>& mods);

/**
* GetEncodingType
* @return CKKS_PACKED_ENCODING
*/
PlaintextEncodings GetEncodingType() const override {
return CKKS_PACKED_ENCODING;
}

/**
* Get method to return the length of plaintext
*
Expand Down Expand Up @@ -215,40 +207,16 @@ class CKKSPackedEncoding : public PlaintextImpl {
value.resize(siz);
}

/**
* Method to compare two plaintext to test for equivalence. This method does
* not test that the plaintext are of the same type.
*
* @param other - the other plaintext to compare to.
* @return whether the two plaintext are equivalent.
*/
bool CompareTo(const PlaintextImpl& other) const override {
const auto& rv = static_cast<const CKKSPackedEncoding&>(other);
return this->value == rv.value;
}

/**
* @brief Destructor method.
*/
static void Destroy();

void PrintValue(std::ostream& out) const override {
// for sanity's sake, trailing zeros get elided into "..."
// out.precision(15);
out << "(";
size_t i = value.size();
while (--i > 0)
if (value[i] != std::complex<double>(0, 0))
break;

for (size_t j = 0; j <= i; j++) {
out << value[j].real() << ", ";
}

out << " ... ); ";
out << "Estimated precision: " << encodingParams->GetPlaintextModulus() - m_logError << " bits" << std::endl;
}

/**
* @brief GetFormattedValues() is called by operator<< and requires a precision as an argument
* @param precision number of decimal digits of precision to print
* @return string with all values and "estimated precision"
*/
std::string GetFormattedValues(int64_t precision) const override {
std::stringstream ss;
ss << "(";
Expand Down Expand Up @@ -279,10 +247,30 @@ class CKKSPackedEncoding : public PlaintextImpl {
double m_logError = 0;

protected:
void PrintValue(std::ostream& out) const override {
out << GetFormattedValues(8) << std::endl;
}

usint GetDefaultSlotSize() {
auto batchSize = GetEncodingParams()->GetBatchSize();
return (0 == batchSize) ? GetElementRingDimension() / 2 : batchSize;
}

/**
* Method to compare two plaintext to test for equivalence. This method does
* not test that the plaintext are of the same type.
*
* @param rhs - the other plaintext to compare to.
* @return whether the two plaintext are equivalent.
*/
bool CompareTo(const PlaintextImpl& rhs) const override {
const auto* el = dynamic_cast<const CKKSPackedEncoding*>(&rhs);
if (el == nullptr)
return false;

return this->value == el->value;
}

/**
* Set modulus and recalculates the vector values to fit the modulus
*
Expand Down
99 changes: 51 additions & 48 deletions src/pke/include/encoding/coefpackedencoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,106 +48,109 @@ namespace lbcrypto {
class CoefPackedEncoding : public PlaintextImpl {
std::vector<int64_t> value;

protected:
/**
* @brief PrintValue() is called by operator<<
* @param out stream to print to
*/
void PrintValue(std::ostream& out) const override {
out << "(";

// for sanity's sake: get rid of all trailing zeroes and print "..." instead
size_t i = value.size();
bool allZeroes = true;
while (i > 0) {
--i;
if (value[i] != 0) {
allZeroes = false;
break;
}
}

if (allZeroes == false) {
for (size_t j = 0; j <= i; ++j)
out << value[j] << ", ";
}
out << "... )";
}

/**
* Method to compare two plaintext to test for equivalence
* Testing that the plaintexts are of the same type done in operator==
*
* @param rhs - the other plaintext to compare to.
* @return whether the two plaintext are equivalent.
*/
bool CompareTo(const PlaintextImpl& rhs) const override {
const auto* el = dynamic_cast<const CoefPackedEncoding*>(&rhs);
if (el == nullptr)
return false;

return this->value == el->value;
}

public:
template <typename T, typename std::enable_if<std::is_same<T, Poly::Params>::value ||
std::is_same<T, NativePoly::Params>::value ||
std::is_same<T, DCRTPoly::Params>::value,
bool>::type = true>
CoefPackedEncoding(std::shared_ptr<T> vp, EncodingParams ep, SCHEME schemeId = SCHEME::INVALID_SCHEME)
: PlaintextImpl(vp, ep, schemeId) {}
: PlaintextImpl(vp, ep, COEF_PACKED_ENCODING, schemeId) {}

template <typename T, typename std::enable_if<std::is_same<T, Poly::Params>::value ||
std::is_same<T, NativePoly::Params>::value ||
std::is_same<T, DCRTPoly::Params>::value,
bool>::type = true>
CoefPackedEncoding(std::shared_ptr<T> vp, EncodingParams ep, const std::vector<int64_t>& coeffs,
SCHEME schemeId = SCHEME::INVALID_SCHEME)
: PlaintextImpl(vp, ep, schemeId), value(coeffs) {}
: PlaintextImpl(vp, ep, COEF_PACKED_ENCODING, schemeId), value(coeffs) {}

virtual ~CoefPackedEncoding() = default;
~CoefPackedEncoding() override = default;

/**
* GetCoeffsValue
* @return the un-encoded scalar
*/
const std::vector<int64_t>& GetCoefPackedValue() const {
const std::vector<int64_t>& GetCoefPackedValue() const override {
return value;
}

/**
* SetIntVectorValue
* @param val integer vector to initialize the plaintext
*/
void SetIntVectorValue(const std::vector<int64_t>& val) {
void SetIntVectorValue(const std::vector<int64_t>& val) override {
value = val;
}

/**
* Encode the plaintext into the Poly
* @return true on success
*/
bool Encode();
bool Encode() override;

/**
* Decode the Poly into the string
* @return true on success
*/
bool Decode();

/**
* GetEncodingType
* @return this is a COEF_PACKED_ENCODING encoding
*/
PlaintextEncodings GetEncodingType() const {
return COEF_PACKED_ENCODING;
}
*/
bool Decode() override;

/**
* Get length of the plaintext
*
* @return number of elements in this plaintext
*/
size_t GetLength() const {
size_t GetLength() const override {
return value.size();
}

/**
* SetLength of the plaintext to the given size
* @param siz
*/
void SetLength(size_t siz) {
void SetLength(size_t siz) override {
value.resize(siz);
}

/**
* Method to compare two plaintext to test for equivalence
* Testing that the plaintexts are of the same type done in operator==
*
* @param other - the other plaintext to compare to.
* @return whether the two plaintext are equivalent.
*/
bool CompareTo(const PlaintextImpl& other) const {
const auto& oth = static_cast<const CoefPackedEncoding&>(other);
return oth.value == this->value;
}

/**
* PrintValue - used by operator<< for this object
* @param out
*/
void PrintValue(std::ostream& out) const {
// for sanity's sake, trailing zeros get elided into "..."
out << "(";
size_t i = value.size();
while (--i > 0)
if (value[i] != 0)
break;

for (size_t j = 0; j <= i; j++)
out << ' ' << value[j];

out << " ... )";
}
};

} /* namespace lbcrypto */
Expand Down
48 changes: 24 additions & 24 deletions src/pke/include/encoding/encodingparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class EncodingParamsImpl : public lbcrypto::Serializable {
/**
* Destructor.
*/
virtual ~EncodingParamsImpl() {}
virtual ~EncodingParamsImpl() = default;

// ACCESSORS

Expand Down Expand Up @@ -253,29 +253,6 @@ class EncodingParamsImpl : public lbcrypto::Serializable {
return !(*this == other);
}

private:
std::ostream& doprint(std::ostream& out) const {
out << "[p=" << m_plaintextModulus << " rootP =" << m_plaintextRootOfUnity << " bigP =" << m_plaintextBigModulus
<< " rootBigP =" << m_plaintextBigRootOfUnity << " g=" << m_plaintextGenerator << " L=" << m_batchSize
<< "]";
return out;
}

// plaintext modulus that is used by all schemes
PlaintextModulus m_plaintextModulus;
// root of unity for plaintext modulus
NativeInteger m_plaintextRootOfUnity;
// big plaintext modulus that is used for arbitrary cyclotomics
NativeInteger m_plaintextBigModulus;
// root of unity for big plaintext modulus
NativeInteger m_plaintextBigRootOfUnity;
// plaintext generator is used for packed encoding (to find the correct
// automorphism index)
uint32_t m_plaintextGenerator;
// maximum batch size used by EvalSumKeyGen for packed encoding
uint32_t m_batchSize;

public:
template <class Archive>
void save(Archive& ar, std::uint32_t const version) const {
ar(::cereal::make_nvp("m", m_plaintextModulus));
Expand Down Expand Up @@ -306,6 +283,29 @@ class EncodingParamsImpl : public lbcrypto::Serializable {
static uint32_t SerializedVersion() {
return 1;
}

protected:
std::ostream& doprint(std::ostream& out) const {
out << "[p=" << m_plaintextModulus << " rootP =" << m_plaintextRootOfUnity << " bigP =" << m_plaintextBigModulus
<< " rootBigP =" << m_plaintextBigRootOfUnity << " g=" << m_plaintextGenerator << " L=" << m_batchSize
<< "]";
return out;
}

private:
// plaintext modulus that is used by all schemes
PlaintextModulus m_plaintextModulus;
// root of unity for plaintext modulus
NativeInteger m_plaintextRootOfUnity;
// big plaintext modulus that is used for arbitrary cyclotomics
NativeInteger m_plaintextBigModulus;
// root of unity for big plaintext modulus
NativeInteger m_plaintextBigRootOfUnity;
// plaintext generator is used for packed encoding (to find the correct
// automorphism index)
uint32_t m_plaintextGenerator;
// maximum batch size used by EvalSumKeyGen for packed encoding
uint32_t m_batchSize;
};

inline std::ostream& operator<<(std::ostream& out, const std::shared_ptr<EncodingParamsImpl>& o) {
Expand Down
Loading