From 8a5a5d7f5ef30c45ca32e4c734be69d99e549efb Mon Sep 17 00:00:00 2001 From: Dmitriy Suponitskiy Date: Mon, 25 Nov 2024 17:30:25 -0500 Subject: [PATCH 1/3] Corrected operator== for the family of cryptoparameters classes --- .../schemebase/base-cryptoparameters.h | 22 +++++-- .../schemebase/rlwe-cryptoparameters.h | 58 +++++++++---------- .../include/schemerns/rns-cryptoparameters.h | 38 ++++++------ 3 files changed, 64 insertions(+), 54 deletions(-) diff --git a/src/pke/include/schemebase/base-cryptoparameters.h b/src/pke/include/schemebase/base-cryptoparameters.h index 76a1bf173..1c0ad4989 100644 --- a/src/pke/include/schemebase/base-cryptoparameters.h +++ b/src/pke/include/schemebase/base-cryptoparameters.h @@ -59,9 +59,9 @@ class CryptoParametersBase : public Serializable { using TugType = typename Element::TugType; public: - CryptoParametersBase() {} + CryptoParametersBase() = default; - virtual ~CryptoParametersBase() {} + virtual ~CryptoParametersBase() = default; /** * Returns the value of plaintext modulus p @@ -99,11 +99,11 @@ class CryptoParametersBase : public Serializable { m_encodingParams->SetPlaintextModulus(plaintextModulus); } - virtual bool operator==(const CryptoParametersBase& cmp) const { - return *m_encodingParams == *(cmp.GetEncodingParams()) && *m_params == *(cmp.GetElementParams()); + bool operator==(const CryptoParametersBase& rhs) const { + return CompareTo(rhs); } - virtual bool operator!=(const CryptoParametersBase& cmp) const { - return !(*this == cmp); + bool operator!=(const CryptoParametersBase& rhs) const { + return !(*this == rhs); } /** @@ -194,6 +194,16 @@ class CryptoParametersBase : public Serializable { m_params = newElemParms; } + /** + * @brief CompareTo() is a method to compare two CryptoParametersBase objects. It is called by operator==() + * + * @param rhs - the other CryptoParametersBase object to compare to. + * @return whether the two CryptoParametersBase objects are equivalent. + */ + virtual bool CompareTo(const CryptoParametersBase& rhs) const { + return (*m_encodingParams == *(rhs.GetEncodingParams()) && *m_params == *(rhs.GetElementParams())); + } + virtual void PrintParameters(std::ostream& out) const { out << "Element Parameters: " << *m_params << std::endl; out << "Encoding Parameters: " << *m_encodingParams << std::endl; diff --git a/src/pke/include/schemebase/rlwe-cryptoparameters.h b/src/pke/include/schemebase/rlwe-cryptoparameters.h index ef5b15cfc..eb8192815 100644 --- a/src/pke/include/schemebase/rlwe-cryptoparameters.h +++ b/src/pke/include/schemebase/rlwe-cryptoparameters.h @@ -127,9 +127,9 @@ class CryptoParametersRLWE : public CryptoParametersBase { } /** - * Destructor + * Virtual Destructor */ - virtual ~CryptoParametersRLWE() {} + ~CryptoParametersRLWE() = default; /** * Returns the value of standard deviation r for discrete Gaussian @@ -414,33 +414,6 @@ class CryptoParametersRLWE : public CryptoParametersBase { m_thresholdNumOfParties = thresholdNumOfParties; } - /** - * == operator to compare to this instance of CryptoParametersRLWE object. - * - * @param &rhs CryptoParameters to check equality against. - */ - bool operator==(const CryptoParametersBase& rhs) const { - const auto* el = dynamic_cast*>(&rhs); - - if (el == nullptr) - return false; - - return CryptoParametersBase::operator==(*el) && - this->GetPlaintextModulus() == el->GetPlaintextModulus() && - *this->GetElementParams() == *el->GetElementParams() && - *this->GetEncodingParams() == *el->GetEncodingParams() && - m_distributionParameter == el->GetDistributionParameter() && - m_assuranceMeasure == el->GetAssuranceMeasure() && m_noiseScale == el->GetNoiseScale() && - m_digitSize == el->GetDigitSize() && m_secretKeyDist == el->GetSecretKeyDist() && - m_stdLevel == el->GetStdLevel() && m_maxRelinSkDeg == el->GetMaxRelinSkDeg() && - m_PREMode == el->GetPREMode() && m_multipartyMode == el->GetMultipartyMode() && - m_executionMode == el->GetExecutionMode() && - m_floodingDistributionParameter == el->GetFloodingDistributionParameter() && - m_statisticalSecurity == el->GetStatisticalSecurity() && - m_numAdversarialQueries == el->GetNumAdversarialQueries() && - m_thresholdNumOfParties == el->GetThresholdNumOfParties(); - } - void PrintParameters(std::ostream& os) const { CryptoParametersBase::PrintParameters(os); @@ -541,6 +514,33 @@ class CryptoParametersRLWE : public CryptoParametersBase { double m_numAdversarialQueries = 1; usint m_thresholdNumOfParties = 1; + + /** + * @brief CompareTo() is a method to compare two CryptoParametersRLWE objects. + * It is called by CryptoParametersBase::operator==() + * @param rhs - the other CryptoParametersRLWE object to compare to. + * @return whether the two CryptoParametersRLWE objects are equivalent. + */ + bool CompareTo(const CryptoParametersBase& rhs) const override { + const auto* el = dynamic_cast*>(&rhs); + if (el == nullptr) + return false; + + return CryptoParametersBase::CompareTo(*el) && + this->GetPlaintextModulus() == el->GetPlaintextModulus() && + *this->GetElementParams() == *el->GetElementParams() && + *this->GetEncodingParams() == *el->GetEncodingParams() && + m_distributionParameter == el->GetDistributionParameter() && + m_assuranceMeasure == el->GetAssuranceMeasure() && m_noiseScale == el->GetNoiseScale() && + m_digitSize == el->GetDigitSize() && m_secretKeyDist == el->GetSecretKeyDist() && + m_stdLevel == el->GetStdLevel() && m_maxRelinSkDeg == el->GetMaxRelinSkDeg() && + m_PREMode == el->GetPREMode() && m_multipartyMode == el->GetMultipartyMode() && + m_executionMode == el->GetExecutionMode() && + m_floodingDistributionParameter == el->GetFloodingDistributionParameter() && + m_statisticalSecurity == el->GetStatisticalSecurity() && + m_numAdversarialQueries == el->GetNumAdversarialQueries() && + m_thresholdNumOfParties == el->GetThresholdNumOfParties(); + } }; } // namespace lbcrypto diff --git a/src/pke/include/schemerns/rns-cryptoparameters.h b/src/pke/include/schemerns/rns-cryptoparameters.h index 0d07d7db4..b252b9b56 100644 --- a/src/pke/include/schemerns/rns-cryptoparameters.h +++ b/src/pke/include/schemerns/rns-cryptoparameters.h @@ -140,7 +140,25 @@ class CryptoParametersRNS : public CryptoParametersRLWE { m_MPIntBootCiphertextCompressionLevel = mPIntBootCiphertextCompressionLevel; } - virtual ~CryptoParametersRNS() {} + ~CryptoParametersRNS() = default; + + /** + * @brief CompareTo() is a method to compare two CryptoParametersRNS objects. + * It is called by CryptoParametersBase::operator==() + * @param rhs - the other CryptoParametersRNS object to compare to. + * @return whether the two CryptoParametersRNS objects are equivalent. + */ + bool CompareTo(const CryptoParametersBase& rhs) const override { + const auto* el = dynamic_cast(&rhs); + if (el == nullptr) + return false; + + return CryptoParametersRLWE::CompareTo(rhs) && m_scalTechnique == el->GetScalingTechnique() && + m_ksTechnique == el->GetKeySwitchTechnique() && m_multTechnique == el->GetMultiplicationTechnique() && + m_encTechnique == el->GetEncryptionTechnique() && m_numPartQ == el->GetNumPartQ() && + m_auxBits == el->GetAuxBits() && m_extraBits == el->GetExtraBits() && m_PREMode == el->GetPREMode() && + m_multipartyMode == el->GetMultipartyMode() && m_executionMode == el->GetExecutionMode(); + } public: /** @@ -183,24 +201,6 @@ class CryptoParametersRNS : public CryptoParametersRLWE { return static_cast(NoiseFlooding::MULTIPARTY_MOD_SIZE * NoiseFlooding::NUM_MODULI_MULTIPARTY); } - /** - * == operator to compare to this instance of CryptoParametersBase object. - * - * @param &rhs CryptoParameters to check equality against. - */ - bool operator==(const CryptoParametersBase& rhs) const override { - const auto* el = dynamic_cast(&rhs); - - if (el == nullptr) - return false; - - return CryptoParametersBase::operator==(rhs) && m_scalTechnique == el->GetScalingTechnique() && - m_ksTechnique == el->GetKeySwitchTechnique() && m_multTechnique == el->GetMultiplicationTechnique() && - m_encTechnique == el->GetEncryptionTechnique() && m_numPartQ == el->GetNumPartQ() && - m_auxBits == el->GetAuxBits() && m_extraBits == el->GetExtraBits() && m_PREMode == el->GetPREMode() && - m_multipartyMode == el->GetMultipartyMode() && m_executionMode == el->GetExecutionMode(); - } - void PrintParameters(std::ostream& os) const override { CryptoParametersBase::PrintParameters(os); } From 11e13198c80cf331a933ba16318c4caef4a1d38e Mon Sep 17 00:00:00 2001 From: Dmitriy Suponitskiy Date: Mon, 25 Nov 2024 23:26:30 -0500 Subject: [PATCH 2/3] Fixed syntax for multiple virtual functions, virtual destructors, made some print functions protected to force users to use operator<<(), fixed multiple issues, etc. --- src/pke/include/encoding/ckkspackedencoding.h | 72 ++++++-------- src/pke/include/encoding/coefpackedencoding.h | 99 ++++++++++--------- src/pke/include/encoding/packedencoding.h | 92 +++++++++-------- src/pke/include/encoding/plaintext.h | 98 ++++++++++-------- src/pke/include/encoding/stringencoding.h | 54 +++++----- .../schemebase/base-cryptoparameters.h | 13 ++- .../schemebase/rlwe-cryptoparameters.h | 26 ++--- .../include/schemerns/rns-cryptoparameters.h | 10 +- 8 files changed, 234 insertions(+), 230 deletions(-) diff --git a/src/pke/include/encoding/ckkspackedencoding.h b/src/pke/include/encoding/ckkspackedencoding.h index 4a3237028..c53698a6b 100644 --- a/src/pke/include/encoding/ckkspackedencoding.h +++ b/src/pke/include/encoding/ckkspackedencoding.h @@ -64,7 +64,7 @@ class CKKSPackedEncoding : public PlaintextImpl { std::is_same::value || std::is_same::value, bool>::type = true> - CKKSPackedEncoding(std::shared_ptr vp, EncodingParams ep) : PlaintextImpl(vp, ep, CKKSRNS_SCHEME) { + CKKSPackedEncoding(std::shared_ptr vp, EncodingParams ep) : PlaintextImpl(vp, ep, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME) { this->slots = GetDefaultSlotSize(); if (this->slots > (GetElementRingDimension() / 2)) { OPENFHE_THROW("The number of slots cannot be larger than half of ring dimension"); @@ -83,7 +83,7 @@ class CKKSPackedEncoding : public PlaintextImpl { bool>::type = true> CKKSPackedEncoding(std::shared_ptr vp, EncodingParams ep, const std::vector>& coeffs, size_t noiseScaleDeg, uint32_t level, double scFact, size_t slots) - : PlaintextImpl(vp, ep, CKKSRNS_SCHEME), value(coeffs) { + : PlaintextImpl(vp, ep, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME), value(coeffs) { // validate the number of slots if ((slots & (slots - 1)) != 0) { OPENFHE_THROW("The number of slots should be a power of two"); @@ -109,7 +109,7 @@ class CKKSPackedEncoding : public PlaintextImpl { * @param rhs - The input object to copy. */ explicit CKKSPackedEncoding(const std::vector>& rhs, size_t slots) - : PlaintextImpl(std::shared_ptr(0), nullptr, CKKSRNS_SCHEME), value(rhs) { + : PlaintextImpl(std::shared_ptr(0), nullptr, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME), value(rhs) { // validate the number of slots if ((slots & (slots - 1)) != 0) { OPENFHE_THROW("The number of slots should be a power of two"); @@ -128,7 +128,7 @@ class CKKSPackedEncoding : public PlaintextImpl { /** * @brief Default empty constructor with empty uninitialized data elements. */ - CKKSPackedEncoding() : PlaintextImpl(std::shared_ptr(0), nullptr, CKKSRNS_SCHEME) { + CKKSPackedEncoding() : PlaintextImpl(std::shared_ptr(0), nullptr, CKKS_PACKED_ENCODING, CKKSRNS_SCHEME) { this->slots = GetDefaultSlotSize(); if (this->slots > (GetElementRingDimension() / 2)) { OPENFHE_THROW("The number of slots cannot be larger than half of ring dimension"); @@ -147,7 +147,7 @@ class CKKSPackedEncoding : public PlaintextImpl { OPENFHE_THROW("CKKSPackedEncoding::Decode() is not implemented. Use CKKSPackedEncoding::Decode(...) instead."); } - bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode); + bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode) override; const std::vector>& GetCKKSPackedValue() const override { return value; @@ -175,14 +175,6 @@ class CKKSPackedEncoding : public PlaintextImpl { const std::vector& b, const std::vector& mods); - /** - * GetEncodingType - * @return CKKS_PACKED_ENCODING - */ - PlaintextEncodings GetEncodingType() const override { - return CKKS_PACKED_ENCODING; - } - /** * Get method to return the length of plaintext * @@ -215,40 +207,16 @@ class CKKSPackedEncoding : public PlaintextImpl { value.resize(siz); } - /** - * Method to compare two plaintext to test for equivalence. This method does - * not test that the plaintext are of the same type. - * - * @param other - the other plaintext to compare to. - * @return whether the two plaintext are equivalent. - */ - bool CompareTo(const PlaintextImpl& other) const override { - const auto& rv = static_cast(other); - return this->value == rv.value; - } - /** * @brief Destructor method. */ static void Destroy(); - void PrintValue(std::ostream& out) const override { - // for sanity's sake, trailing zeros get elided into "..." - // out.precision(15); - out << "("; - size_t i = value.size(); - while (--i > 0) - if (value[i] != std::complex(0, 0)) - break; - - for (size_t j = 0; j <= i; j++) { - out << value[j].real() << ", "; - } - - out << " ... ); "; - out << "Estimated precision: " << encodingParams->GetPlaintextModulus() - m_logError << " bits" << std::endl; - } - + /** + * @brief GetFormattedValues() is called by operator<< and requires a precision as an argument + * @param precision number of decimal digits of precision to print + * @return string with all values and "estimated precision" + */ std::string GetFormattedValues(int64_t precision) const override { std::stringstream ss; ss << "("; @@ -279,10 +247,30 @@ class CKKSPackedEncoding : public PlaintextImpl { double m_logError = 0; protected: + void PrintValue(std::ostream& out) const override { + out << GetFormattedValues(8) << std::endl; + } + usint GetDefaultSlotSize() { auto batchSize = GetEncodingParams()->GetBatchSize(); return (0 == batchSize) ? GetElementRingDimension() / 2 : batchSize; } + + /** + * Method to compare two plaintext to test for equivalence. This method does + * not test that the plaintext are of the same type. + * + * @param rhs - the other plaintext to compare to. + * @return whether the two plaintext are equivalent. + */ + bool CompareTo(const PlaintextImpl& rhs) const override { + const auto* el = dynamic_cast(&rhs); + if (el == nullptr) + return false; + + return this->value == el->value; + } + /** * Set modulus and recalculates the vector values to fit the modulus * diff --git a/src/pke/include/encoding/coefpackedencoding.h b/src/pke/include/encoding/coefpackedencoding.h index 089bbe1a8..f89f19808 100644 --- a/src/pke/include/encoding/coefpackedencoding.h +++ b/src/pke/include/encoding/coefpackedencoding.h @@ -48,13 +48,54 @@ namespace lbcrypto { class CoefPackedEncoding : public PlaintextImpl { std::vector value; +protected: + /** + * @brief PrintValue() is called by operator<< + * @param out stream to print to + */ + void PrintValue(std::ostream& out) const override { + out << "("; + + // for sanity's sake: get rid of all trailing zeroes and print "..." instead + size_t i = value.size(); + bool allZeroes = true; + while (i > 0) { + --i; + if (value[i] != 0) { + allZeroes = false; + break; + } + } + + if (allZeroes == false) { + for (size_t j = 0; j <= i; ++j) + out << value[j] << ", "; + } + out << "... )"; + } + + /** + * Method to compare two plaintext to test for equivalence + * Testing that the plaintexts are of the same type done in operator== + * + * @param rhs - the other plaintext to compare to. + * @return whether the two plaintext are equivalent. + */ + bool CompareTo(const PlaintextImpl& rhs) const override { + const auto* el = dynamic_cast(&rhs); + if (el == nullptr) + return false; + + return this->value == el->value; + } + public: template ::value || std::is_same::value || std::is_same::value, bool>::type = true> CoefPackedEncoding(std::shared_ptr vp, EncodingParams ep, SCHEME schemeId = SCHEME::INVALID_SCHEME) - : PlaintextImpl(vp, ep, schemeId) {} + : PlaintextImpl(vp, ep, COEF_PACKED_ENCODING, schemeId) {} template ::value || std::is_same::value || @@ -62,15 +103,15 @@ class CoefPackedEncoding : public PlaintextImpl { bool>::type = true> CoefPackedEncoding(std::shared_ptr vp, EncodingParams ep, const std::vector& coeffs, SCHEME schemeId = SCHEME::INVALID_SCHEME) - : PlaintextImpl(vp, ep, schemeId), value(coeffs) {} + : PlaintextImpl(vp, ep, COEF_PACKED_ENCODING, schemeId), value(coeffs) {} - virtual ~CoefPackedEncoding() = default; + ~CoefPackedEncoding() override = default; /** * GetCoeffsValue * @return the un-encoded scalar */ - const std::vector& GetCoefPackedValue() const { + const std::vector& GetCoefPackedValue() const override { return value; } @@ -78,7 +119,7 @@ class CoefPackedEncoding : public PlaintextImpl { * SetIntVectorValue * @param val integer vector to initialize the plaintext */ - void SetIntVectorValue(const std::vector& val) { + void SetIntVectorValue(const std::vector& val) override { value = val; } @@ -86,28 +127,20 @@ class CoefPackedEncoding : public PlaintextImpl { * Encode the plaintext into the Poly * @return true on success */ - bool Encode(); + bool Encode() override; /** * Decode the Poly into the string * @return true on success - */ - bool Decode(); - - /** - * GetEncodingType - * @return this is a COEF_PACKED_ENCODING encoding - */ - PlaintextEncodings GetEncodingType() const { - return COEF_PACKED_ENCODING; - } + */ + bool Decode() override; /** * Get length of the plaintext * * @return number of elements in this plaintext */ - size_t GetLength() const { + size_t GetLength() const override { return value.size(); } @@ -115,39 +148,9 @@ class CoefPackedEncoding : public PlaintextImpl { * SetLength of the plaintext to the given size * @param siz */ - void SetLength(size_t siz) { + void SetLength(size_t siz) override { value.resize(siz); } - - /** - * Method to compare two plaintext to test for equivalence - * Testing that the plaintexts are of the same type done in operator== - * - * @param other - the other plaintext to compare to. - * @return whether the two plaintext are equivalent. - */ - bool CompareTo(const PlaintextImpl& other) const { - const auto& oth = static_cast(other); - return oth.value == this->value; - } - - /** - * PrintValue - used by operator<< for this object - * @param out - */ - void PrintValue(std::ostream& out) const { - // for sanity's sake, trailing zeros get elided into "..." - out << "("; - size_t i = value.size(); - while (--i > 0) - if (value[i] != 0) - break; - - for (size_t j = 0; j <= i; j++) - out << ' ' << value[j]; - - out << " ... )"; - } }; } /* namespace lbcrypto */ diff --git a/src/pke/include/encoding/packedencoding.h b/src/pke/include/encoding/packedencoding.h index 97a02c4a0..6d76f1467 100644 --- a/src/pke/include/encoding/packedencoding.h +++ b/src/pke/include/encoding/packedencoding.h @@ -70,21 +70,21 @@ class PackedEncoding : public PlaintextImpl { std::is_same::value || std::is_same::value, bool>::type = true> - PackedEncoding(std::shared_ptr vp, EncodingParams ep) : PlaintextImpl(vp, ep) {} + PackedEncoding(std::shared_ptr vp, EncodingParams ep) : PlaintextImpl(vp, ep, PACKED_ENCODING) {} template ::value || std::is_same::value || std::is_same::value, bool>::type = true> PackedEncoding(std::shared_ptr vp, EncodingParams ep, const std::vector& coeffs) - : PlaintextImpl(vp, ep), value(coeffs) {} + : PlaintextImpl(vp, ep, PACKED_ENCODING), value(coeffs) {} template ::value || std::is_same::value || std::is_same::value, bool>::type = true> PackedEncoding(std::shared_ptr vp, EncodingParams ep, std::initializer_list coeffs) - : PlaintextImpl(vp, ep), value(coeffs) {} + : PlaintextImpl(vp, ep, PACKED_ENCODING), value(coeffs) {} /** * @brief Constructs a container with a copy of each of the elements in rhs, @@ -92,7 +92,7 @@ class PackedEncoding : public PlaintextImpl { * @param rhs - The input object to copy. */ explicit PackedEncoding(const std::vector& rhs) - : PlaintextImpl(std::shared_ptr(0), nullptr), value(rhs) {} + : PlaintextImpl(std::shared_ptr(0), nullptr, PACKED_ENCODING), value(rhs) {} /** * @brief Constructs a container with a copy of each of the elements in il, in @@ -100,22 +100,22 @@ class PackedEncoding : public PlaintextImpl { * @param arr the list to copy. */ PackedEncoding(std::initializer_list arr) - : PlaintextImpl(std::shared_ptr(0), nullptr), value(arr) {} + : PlaintextImpl(std::shared_ptr(0), nullptr, PACKED_ENCODING), value(arr) {} /** * @brief Default empty constructor with empty uninitialized data elements. */ - PackedEncoding() : PlaintextImpl(std::shared_ptr(0), nullptr), value() {} + PackedEncoding() : PlaintextImpl(std::shared_ptr(0), nullptr, PACKED_ENCODING), value() {} static usint GetAutomorphismGenerator(usint m) { return m_automorphismGenerator[m]; } - bool Encode(); + bool Encode() override; - bool Decode(); + bool Decode() override; - const std::vector& GetPackedValue() const { + const std::vector& GetPackedValue() const override { return value; } @@ -123,24 +123,16 @@ class PackedEncoding : public PlaintextImpl { * SetIntVectorValue * @param val integer vector to initialize the plaintext */ - void SetIntVectorValue(const std::vector& val) { + void SetIntVectorValue(const std::vector& val) override { value = val; } - /** - * GetEncodingType - * @return PACKED_ENCODING - */ - PlaintextEncodings GetEncodingType() const { - return PACKED_ENCODING; - } - /** * Get method to return the length of plaintext * * @return the length of the plaintext in terms of the number of bits. */ - size_t GetLength() const { + size_t GetLength() const override { return value.size(); } @@ -164,39 +156,53 @@ class PackedEncoding : public PlaintextImpl { * SetLength of the plaintext to the given size * @param siz */ - void SetLength(size_t siz) { + void SetLength(size_t siz) override { value.resize(siz); } /** - * Method to compare two plaintext to test for equivalence. This method does - * not test that the plaintext are of the same type. - * - * @param other - the other plaintext to compare to. - * @return whether the two plaintext are equivalent. - */ - bool CompareTo(const PlaintextImpl& other) const { - const auto& rv = static_cast(other); - return this->value == rv.value; - } - - /** - * @brief Destructor method. - */ + * @brief Destructor method. + */ static void Destroy(); - void PrintValue(std::ostream& out) const { - // for sanity's sake, trailing zeros get elided into "..." +protected: + /** + * @brief PrintValue() is called by operator<< + * @param out stream to print to + */ + void PrintValue(std::ostream& out) const override { out << "("; - size_t i = value.size(); - while (--i > 0) - if (value[i] != 0) + // for sanity's sake: get rid of all trailing zeroes and print "..." instead + size_t i = value.size(); + bool allZeroes = true; + while (i > 0) { + --i; + if (value[i] != 0) { + allZeroes = false; break; + } + } + + if (allZeroes == false) { + for (size_t j = 0; j <= i; ++j) + out << value[j] << ", "; + } + out << "... )"; + } - for (size_t j = 0; j <= i; j++) - out << ' ' << value[j]; - - out << " ... )"; + /** + * Method to compare two plaintext to test for equivalence. This method does + * not test that the plaintext are of the same type. + * + * @param rhs - the other plaintext to compare to. + * @return whether the two plaintext are equivalent. + */ + bool CompareTo(const PlaintextImpl& rhs) const override { + const auto* el = dynamic_cast(&rhs); + if (el == nullptr) + return false; + + return this->value == el->value; } private: diff --git a/src/pke/include/encoding/plaintext.h b/src/pke/include/encoding/plaintext.h index 9e419c599..edba4ba12 100644 --- a/src/pke/include/encoding/plaintext.h +++ b/src/pke/include/encoding/plaintext.h @@ -83,32 +83,52 @@ class PlaintextImpl { size_t level = 0; size_t noiseScaleDeg = 1; usint slots = 0; + PlaintextEncodings ptxtEncoding = INVALID_ENCODING; SCHEME schemeID; +protected: + /** + * @brief PrintValue() is called by operator<< + * @param out + */ + virtual void PrintValue(std::ostream& out) const = 0; + + /** + * Method to compare two plaintext to test for equivalence. + * This method is called by operator== + * + * @param other - the other plaintext to compare to. + * @return whether the two plaintext are equivalent. + */ + virtual bool CompareTo(const PlaintextImpl& other) const = 0; + public: - PlaintextImpl(const std::shared_ptr& vp, EncodingParams ep, SCHEME schemeTag = SCHEME::INVALID_SCHEME, - bool isEncoded = false) + PlaintextImpl(const std::shared_ptr& vp, EncodingParams ep, PlaintextEncodings encoding, + SCHEME schemeTag = SCHEME::INVALID_SCHEME, bool isEncoded = false) : isEncoded(isEncoded), typeFlag(IsPoly), encodingParams(std::move(ep)), encodedVector(vp, Format::COEFFICIENT), + ptxtEncoding(encoding), schemeID(schemeTag) {} - PlaintextImpl(const std::shared_ptr& vp, EncodingParams ep, + PlaintextImpl(const std::shared_ptr& vp, EncodingParams ep, PlaintextEncodings encoding, SCHEME schemeTag = SCHEME::INVALID_SCHEME, bool isEncoded = false) : isEncoded(isEncoded), typeFlag(IsNativePoly), encodingParams(std::move(ep)), encodedNativeVector(vp, Format::COEFFICIENT), + ptxtEncoding(encoding), schemeID(schemeTag) {} - PlaintextImpl(const std::shared_ptr& vp, EncodingParams ep, + PlaintextImpl(const std::shared_ptr& vp, EncodingParams ep, PlaintextEncodings encoding, SCHEME schemeTag = SCHEME::INVALID_SCHEME, bool isEncoded = false) : isEncoded(isEncoded), typeFlag(IsDCRTPoly), encodingParams(std::move(ep)), encodedVector(vp, Format::COEFFICIENT), encodedVectorDCRT(vp, Format::COEFFICIENT), + ptxtEncoding(encoding), schemeID(schemeTag) {} PlaintextImpl(const PlaintextImpl& rhs) @@ -122,6 +142,7 @@ class PlaintextImpl { level(rhs.level), noiseScaleDeg(rhs.noiseScaleDeg), slots(rhs.slots), + ptxtEncoding(rhs.ptxtEncoding), schemeID(rhs.schemeID) {} PlaintextImpl(PlaintextImpl&& rhs) @@ -135,15 +156,18 @@ class PlaintextImpl { level(rhs.level), noiseScaleDeg(rhs.noiseScaleDeg), slots(rhs.slots), + ptxtEncoding(rhs.ptxtEncoding), schemeID(rhs.schemeID) {} - virtual ~PlaintextImpl() {} + virtual ~PlaintextImpl() = default; /** * GetEncodingType * @return Encoding type used by this plaintext */ - virtual PlaintextEncodings GetEncodingType() const = 0; + PlaintextEncodings GetEncodingType() const { + return ptxtEncoding; + } /** * Get the scaling factor of the plaintext for CKKS-based plaintexts. @@ -203,10 +227,13 @@ class PlaintextImpl { virtual bool Encode() = 0; /** - * Decode the polynomial into the plaintext + * @brief Decode the polynomial into the plaintext * @return */ virtual bool Decode() = 0; + virtual bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode) { + OPENFHE_THROW("Not implemented"); + } /** * Calculate and return lower bound that can be encoded with the plaintext @@ -358,7 +385,7 @@ class PlaintextImpl { OPENFHE_THROW("not a packed coefficient vector"); } virtual const std::vector& GetPackedValue() const { - OPENFHE_THROW("not a packed coefficient vector"); + OPENFHE_THROW("not a packed vector"); } virtual const std::vector>& GetCKKSPackedValue() const { OPENFHE_THROW("not a packed vector of complex numbers"); @@ -373,15 +400,6 @@ class PlaintextImpl { OPENFHE_THROW("does not support an int vector"); } - /** - * Method to compare two plaintext to test for equivalence. - * This method is called by operator== - * - * @param other - the other plaintext to compare to. - * @return whether the two plaintext are equivalent. - */ - virtual bool CompareTo(const PlaintextImpl& other) const = 0; - /** * operator== for plaintexts. This method makes sure the plaintexts are of * the same type. @@ -398,39 +416,33 @@ class PlaintextImpl { } /** - * operator<< for ostream integration - calls PrintValue - * @param out - * @param item - * @return - */ - friend std::ostream& operator<<(std::ostream& out, const PlaintextImpl& item); - - /** - * PrintValue is called by operator<< - * @param out - */ - virtual void PrintValue(std::ostream& out) const = 0; + * @brief operator<< for ostream integration - calls PrintValue() + * @param out + * @param item + * @return + */ + friend std::ostream& operator<<(std::ostream& out, const PlaintextImpl& item) { + item.PrintValue(out); + return out; + } + friend std::ostream& operator<<(std::ostream& out, const Plaintext& item) { + if (item) + out << *item; // Call the non-pointer version + else + OPENFHE_THROW("Cannot de-reference nullptr for printing"); + return out; + } /** - * GetFormattedValues() has a logic similar to PrintValue(), but requires a precision as an argument - * @param precision number of decimal digits of precision to print - * @return string with all values and "estimated precision" - */ + * @brief GetFormattedValues() is similar to PrintValue() and requires a precision as an argument + * @param precision number of decimal digits of precision to print + * @return string with all values + */ virtual std::string GetFormattedValues(int64_t precision) const { OPENFHE_THROW("not implemented"); } }; -inline std::ostream& operator<<(std::ostream& out, const PlaintextImpl& item) { - item.PrintValue(out); - return out; -} - -inline std::ostream& operator<<(std::ostream& out, const Plaintext& item) { - item->PrintValue(out); - return out; -} - inline bool operator==(const Plaintext& p1, const Plaintext& p2) { return *p1 == *p2; } diff --git a/src/pke/include/encoding/stringencoding.h b/src/pke/include/encoding/stringencoding.h index 8ab1fbe27..146d24683 100644 --- a/src/pke/include/encoding/stringencoding.h +++ b/src/pke/include/encoding/stringencoding.h @@ -53,25 +53,25 @@ class StringEncoding : public PlaintextImpl { std::is_same::value || std::is_same::value, bool>::type = true> - StringEncoding(std::shared_ptr vp, EncodingParams ep) : PlaintextImpl(vp, ep) {} + StringEncoding(std::shared_ptr vp, EncodingParams ep) : PlaintextImpl(vp, ep, STRING_ENCODING) {} template ::value || std::is_same::value || std::is_same::value, bool>::type = true> StringEncoding(std::shared_ptr vp, EncodingParams ep, const std::string& str) - : PlaintextImpl(vp, ep), ptx(str) {} + : PlaintextImpl(vp, ep, STRING_ENCODING), ptx(str) {} // TODO provide wide-character version (for unicode); right now this class // only supports strings of 7-bit ASCII characters - virtual ~StringEncoding() {} + ~StringEncoding() override = default; /** * GetStringValue * @return the un-encoded string */ - const std::string& GetStringValue() const { + const std::string& GetStringValue() const override { return ptx; } @@ -79,7 +79,7 @@ class StringEncoding : public PlaintextImpl { * SetStringValue * @param val to initialize the Plaintext */ - void SetStringValue(const std::string& value) { + void SetStringValue(const std::string& value) override { ptx = value; } @@ -87,48 +87,44 @@ class StringEncoding : public PlaintextImpl { * Encode the plaintext into the Poly * @return true on success */ - bool Encode(); + bool Encode() override; /** * Decode the Poly into the string * @return true on success */ - bool Decode(); - - /** - * GetEncodingType - * @return STRING_ENCODING - */ - PlaintextEncodings GetEncodingType() const { - return STRING_ENCODING; - } + bool Decode() override; /** * Get length of the plaintext * * @return number of elements in this plaintext */ - size_t GetLength() const { + size_t GetLength() const override { return ptx.size(); } +protected: /** - * Method to compare two plaintext to test for equivalence - * Testing that the plaintexts are of the same type done in operator== - * - * @param other - the other plaintext to compare to. - * @return whether the two plaintext are equivalent. - */ - bool CompareTo(const PlaintextImpl& other) const { - const auto& oth = static_cast(other); - return oth.ptx == this->ptx; + * Method to compare two plaintext to test for equivalence + * Testing that the plaintexts are of the same type done in operator== + * + * @param rhs - the other plaintext to compare to. + * @return whether the two plaintext are equivalent. + */ + bool CompareTo(const PlaintextImpl& rhs) const override { + const auto* el = dynamic_cast(&rhs); + if (el == nullptr) + return false; + + return this->ptx == el->ptx; } /** - * PrintValue - used by operator<< for this object - * @param out - */ - void PrintValue(std::ostream& out) const { + * PrintValue - used by operator<< for this object + * @param out + */ + void PrintValue(std::ostream& out) const override { out << ptx; } }; diff --git a/src/pke/include/schemebase/base-cryptoparameters.h b/src/pke/include/schemebase/base-cryptoparameters.h index 1c0ad4989..5ae25e2c8 100644 --- a/src/pke/include/schemebase/base-cryptoparameters.h +++ b/src/pke/include/schemebase/base-cryptoparameters.h @@ -68,7 +68,7 @@ class CryptoParametersBase : public Serializable { * * @return the plaintext modulus. */ - virtual const PlaintextModulus& GetPlaintextModulus() const { + PlaintextModulus GetPlaintextModulus() const { return m_encodingParams->GetPlaintextModulus(); } @@ -77,7 +77,7 @@ class CryptoParametersBase : public Serializable { * * @return the ring element parameters. */ - virtual const std::shared_ptr GetElementParams() const { + const std::shared_ptr GetElementParams() const { return m_params; } @@ -88,14 +88,14 @@ class CryptoParametersBase : public Serializable { * * @return the encoding parameters. */ - virtual const EncodingParams GetEncodingParams() const { + const EncodingParams GetEncodingParams() const { return m_encodingParams; } /** * Sets the value of plaintext modulus p */ - virtual void SetPlaintextModulus(const PlaintextModulus& plaintextModulus) { + void SetPlaintextModulus(PlaintextModulus plaintextModulus) { m_encodingParams->SetPlaintextModulus(plaintextModulus); } @@ -119,7 +119,7 @@ class CryptoParametersBase : public Serializable { return out; } - virtual usint GetDigitSize() const { + virtual uint32_t GetDigitSize() const { return 0; } @@ -167,7 +167,7 @@ class CryptoParametersBase : public Serializable { ar(::cereal::make_nvp("enp", m_encodingParams)); } - std::string SerializedObjectName() const { + std::string SerializedObjectName() const override { return "CryptoParametersBase"; } static uint32_t SerializedVersion() { @@ -209,7 +209,6 @@ class CryptoParametersBase : public Serializable { out << "Encoding Parameters: " << *m_encodingParams << std::endl; } -protected: // element-specific parameters std::shared_ptr m_params; diff --git a/src/pke/include/schemebase/rlwe-cryptoparameters.h b/src/pke/include/schemebase/rlwe-cryptoparameters.h index eb8192815..86e4ae44c 100644 --- a/src/pke/include/schemebase/rlwe-cryptoparameters.h +++ b/src/pke/include/schemebase/rlwe-cryptoparameters.h @@ -129,7 +129,7 @@ class CryptoParametersRLWE : public CryptoParametersBase { /** * Virtual Destructor */ - ~CryptoParametersRLWE() = default; + ~CryptoParametersRLWE() override = default; /** * Returns the value of standard deviation r for discrete Gaussian @@ -174,7 +174,7 @@ class CryptoParametersRLWE : public CryptoParametersBase { * * @return the digit size. */ - usint GetDigitSize() const { + uint32_t GetDigitSize() const override { return m_digitSize; } @@ -184,7 +184,7 @@ class CryptoParametersRLWE : public CryptoParametersBase { * * @return maximum power of secret key */ - uint32_t GetMaxRelinSkDeg() const { + uint32_t GetMaxRelinSkDeg() const override { return m_maxRelinSkDeg; } @@ -414,14 +414,6 @@ class CryptoParametersRLWE : public CryptoParametersBase { m_thresholdNumOfParties = thresholdNumOfParties; } - void PrintParameters(std::ostream& os) const { - CryptoParametersBase::PrintParameters(os); - - os << "Distrib parm " << GetDistributionParameter() << ", Assurance measure " << GetAssuranceMeasure() - << ", Noise scale " << GetNoiseScale() << ", Digit Size " << GetDigitSize() << ", SecretKeyDist " - << GetSecretKeyDist() << ", Standard security level " << GetStdLevel() << std::endl; - } - template void save(Archive& ar, std::uint32_t const version) const { ar(::cereal::base_class>(this)); @@ -465,7 +457,7 @@ class CryptoParametersRLWE : public CryptoParametersBase { m_dggFlooding.SetStd(m_floodingDistributionParameter); } - std::string SerializedObjectName() const { + std::string SerializedObjectName() const override { return "CryptoParametersRLWE"; } @@ -479,7 +471,7 @@ class CryptoParametersRLWE : public CryptoParametersBase { // noise scale PlaintextModulus m_noiseScale = 1; // digit size - usint m_digitSize = 1; + uint32_t m_digitSize = 1; // the highest power of secret key for which relinearization key is generated uint32_t m_maxRelinSkDeg = 2; // specifies whether the secret polynomials are generated from discrete @@ -541,6 +533,14 @@ class CryptoParametersRLWE : public CryptoParametersBase { m_numAdversarialQueries == el->GetNumAdversarialQueries() && m_thresholdNumOfParties == el->GetThresholdNumOfParties(); } + + void PrintParameters(std::ostream& os) const override { + CryptoParametersBase::PrintParameters(os); + + os << "Distrib parm " << GetDistributionParameter() << ", Assurance measure " << GetAssuranceMeasure() + << ", Noise scale " << GetNoiseScale() << ", Digit Size " << GetDigitSize() << ", SecretKeyDist " + << GetSecretKeyDist() << ", Standard security level " << GetStdLevel() << std::endl; + } }; } // namespace lbcrypto diff --git a/src/pke/include/schemerns/rns-cryptoparameters.h b/src/pke/include/schemerns/rns-cryptoparameters.h index b252b9b56..75588b682 100644 --- a/src/pke/include/schemerns/rns-cryptoparameters.h +++ b/src/pke/include/schemerns/rns-cryptoparameters.h @@ -140,7 +140,7 @@ class CryptoParametersRNS : public CryptoParametersRLWE { m_MPIntBootCiphertextCompressionLevel = mPIntBootCiphertextCompressionLevel; } - ~CryptoParametersRNS() = default; + ~CryptoParametersRNS() override = default; /** * @brief CompareTo() is a method to compare two CryptoParametersRNS objects. @@ -160,6 +160,10 @@ class CryptoParametersRNS : public CryptoParametersRLWE { m_multipartyMode == el->GetMultipartyMode() && m_executionMode == el->GetExecutionMode(); } + void PrintParameters(std::ostream& os) const override { + CryptoParametersRLWE::PrintParameters(os); + } + public: /** * Computes all tables needed for decryption, homomorphic multiplication and key switching. @@ -201,10 +205,6 @@ class CryptoParametersRNS : public CryptoParametersRLWE { return static_cast(NoiseFlooding::MULTIPARTY_MOD_SIZE * NoiseFlooding::NUM_MODULI_MULTIPARTY); } - void PrintParameters(std::ostream& os) const override { - CryptoParametersBase::PrintParameters(os); - } - ///////////////////////////////////// // PrecomputeCRTTables ///////////////////////////////////// From cb5a58b1aa058e14b361f7b430f9fc616226c035 Mon Sep 17 00:00:00 2001 From: Dmitriy Suponitskiy Date: Thu, 28 Nov 2024 00:18:23 -0500 Subject: [PATCH 3/3] Changes for additional print functions --- .../lattice/hal/default/ildcrtparams.h | 3 +- .../include/lattice/hal/default/ilparams.h | 2 +- src/pke/include/encoding/encodingparams.h | 48 +++++++++---------- src/pke/include/metadata.h | 22 ++++----- src/pke/unittest/utils/UnitTestMetadataTest.h | 20 ++++---- 5 files changed, 48 insertions(+), 47 deletions(-) diff --git a/src/core/include/lattice/hal/default/ildcrtparams.h b/src/core/include/lattice/hal/default/ildcrtparams.h index f85157476..284921758 100644 --- a/src/core/include/lattice/hal/default/ildcrtparams.h +++ b/src/core/include/lattice/hal/default/ildcrtparams.h @@ -345,7 +345,7 @@ class ILDCRTParams final : public ElemParams { return 1; } -private: +protected: std::ostream& doprint(std::ostream& out) const override { out << "ILDCRTParams "; ElemParams::doprint(out); @@ -355,6 +355,7 @@ class ILDCRTParams final : public ElemParams { return out << std::endl; } +private: // array of smaller ILParams std::vector> m_params; }; diff --git a/src/core/include/lattice/hal/default/ilparams.h b/src/core/include/lattice/hal/default/ilparams.h index ce18606fa..e08b1a732 100644 --- a/src/core/include/lattice/hal/default/ilparams.h +++ b/src/core/include/lattice/hal/default/ilparams.h @@ -152,7 +152,7 @@ class ILParamsImpl final : public ElemParams { return 1; } -private: +protected: std::ostream& doprint(std::ostream& out) const override { out << "ILParams "; ElemParams::doprint(out); diff --git a/src/pke/include/encoding/encodingparams.h b/src/pke/include/encoding/encodingparams.h index cd213c38f..e75c7f5e9 100644 --- a/src/pke/include/encoding/encodingparams.h +++ b/src/pke/include/encoding/encodingparams.h @@ -124,7 +124,7 @@ class EncodingParamsImpl : public lbcrypto::Serializable { /** * Destructor. */ - virtual ~EncodingParamsImpl() {} + virtual ~EncodingParamsImpl() = default; // ACCESSORS @@ -253,29 +253,6 @@ class EncodingParamsImpl : public lbcrypto::Serializable { return !(*this == other); } -private: - std::ostream& doprint(std::ostream& out) const { - out << "[p=" << m_plaintextModulus << " rootP =" << m_plaintextRootOfUnity << " bigP =" << m_plaintextBigModulus - << " rootBigP =" << m_plaintextBigRootOfUnity << " g=" << m_plaintextGenerator << " L=" << m_batchSize - << "]"; - return out; - } - - // plaintext modulus that is used by all schemes - PlaintextModulus m_plaintextModulus; - // root of unity for plaintext modulus - NativeInteger m_plaintextRootOfUnity; - // big plaintext modulus that is used for arbitrary cyclotomics - NativeInteger m_plaintextBigModulus; - // root of unity for big plaintext modulus - NativeInteger m_plaintextBigRootOfUnity; - // plaintext generator is used for packed encoding (to find the correct - // automorphism index) - uint32_t m_plaintextGenerator; - // maximum batch size used by EvalSumKeyGen for packed encoding - uint32_t m_batchSize; - -public: template void save(Archive& ar, std::uint32_t const version) const { ar(::cereal::make_nvp("m", m_plaintextModulus)); @@ -306,6 +283,29 @@ class EncodingParamsImpl : public lbcrypto::Serializable { static uint32_t SerializedVersion() { return 1; } + +protected: + std::ostream& doprint(std::ostream& out) const { + out << "[p=" << m_plaintextModulus << " rootP =" << m_plaintextRootOfUnity << " bigP =" << m_plaintextBigModulus + << " rootBigP =" << m_plaintextBigRootOfUnity << " g=" << m_plaintextGenerator << " L=" << m_batchSize + << "]"; + return out; + } + +private: + // plaintext modulus that is used by all schemes + PlaintextModulus m_plaintextModulus; + // root of unity for plaintext modulus + NativeInteger m_plaintextRootOfUnity; + // big plaintext modulus that is used for arbitrary cyclotomics + NativeInteger m_plaintextBigModulus; + // root of unity for big plaintext modulus + NativeInteger m_plaintextBigRootOfUnity; + // plaintext generator is used for packed encoding (to find the correct + // automorphism index) + uint32_t m_plaintextGenerator; + // maximum batch size used by EvalSumKeyGen for packed encoding + uint32_t m_batchSize; }; inline std::ostream& operator<<(std::ostream& out, const std::shared_ptr& o) { diff --git a/src/pke/include/metadata.h b/src/pke/include/metadata.h index 54ca7241e..8a08c6fea 100644 --- a/src/pke/include/metadata.h +++ b/src/pke/include/metadata.h @@ -92,20 +92,11 @@ class Metadata { } /** - * A method that prints the contents of metadata objects. - * Please override in subclasses to print all members. - */ - virtual std::ostream& print(std::ostream& out) const { - out << "[ ]" << std::endl; - return out; - } - - /** - * << operator implements by calling member method print. + * << operator implements by calling member method PrintMetadata. * This is a friend method and cannot be overriden by subclasses. */ friend std::ostream& operator<<(std::ostream& out, const Metadata& m) { - m.print(out); + m.PrintMetadata(out); return out; } @@ -139,6 +130,15 @@ class Metadata { static uint32_t SerializedVersion() { return 1; } + +protected: + /** + * A method that prints the contents of metadata objects. + * Please override in subclasses to print all members. + */ + virtual std::ostream& PrintMetadata(std::ostream& out) const { + OPENFHE_THROW("Not implemented"); + } }; } // end namespace lbcrypto diff --git a/src/pke/unittest/utils/UnitTestMetadataTest.h b/src/pke/unittest/utils/UnitTestMetadataTest.h index d4807d277..7c88c69a8 100644 --- a/src/pke/unittest/utils/UnitTestMetadataTest.h +++ b/src/pke/unittest/utils/UnitTestMetadataTest.h @@ -74,7 +74,7 @@ class MetadataTest : public Metadata { * the Clone method. * */ - std::shared_ptr Clone() const { + std::shared_ptr Clone() const override { auto mdata = std::make_shared(); mdata->m_s = this->m_s; return mdata; @@ -97,7 +97,7 @@ class MetadataTest : public Metadata { /** * Defines how to check equality between objects of this class. */ - bool operator==(const Metadata& mdata) const { + bool operator==(const Metadata& mdata) const override { try { const MetadataTest& mdataTest = dynamic_cast(mdata); return m_s == mdataTest.GetMetadata(); // All Metadata objects without @@ -108,14 +108,6 @@ class MetadataTest : public Metadata { } } - /** - * Defines how to print the contents of objects of this class. - */ - std::ostream& print(std::ostream& out) const { - out << "[ " << m_s << " ]"; - return out; - } - /** * save method for serialization */ @@ -201,6 +193,14 @@ class MetadataTest : public Metadata { } protected: + /** + * Defines how to print the contents of objects of this class. + */ + std::ostream& PrintMetadata(std::ostream& out) const override { + out << "[ " << m_s << " ]"; + return out; + } + std::string m_s; };