Skip to content

Commit 2834bfe

Browse files
authored
Merge pull request #1097 from hsanzg/linear_polys
Fix evaluation of linear functions with `EvalPoly`
2 parents 26bdb8c + 5594dd4 commit 2834bfe

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

src/pke/lib/scheme/ckksrns/ckksrns-advancedshe.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,12 +361,12 @@ std::shared_ptr<seriesPowers<DCRTPoly>> AdvancedSHECKKSRNS::EvalPowers(
361361
}
362362

363363
template <typename VectorDataType>
364-
Ciphertext<DCRTPoly> internalEvalPolyLinearWithPrecomp(std::vector<Ciphertext<DCRTPoly>>& powers,
365-
const std::vector<VectorDataType>& coefficients) {
366-
const uint32_t k = coefficients.size() - 1;
367-
if (k <= 1)
364+
static inline Ciphertext<DCRTPoly> internalEvalPolyLinearWithPrecomp(std::vector<Ciphertext<DCRTPoly>>& powers,
365+
const std::vector<VectorDataType>& coefficients) {
366+
if (coefficients.size() < 2)
368367
OPENFHE_THROW("The coefficients vector should contain at least 2 elements");
369368

369+
uint32_t k = coefficients.size() - 1;
370370
if (!IsNotEqualZero(coefficients[k]))
371371
OPENFHE_THROW("EvalPolyLinear: The highest-order coefficient cannot be set to 0.");
372372

src/pke/unittest/utckksrns/UnitTestCKKSrns.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,6 +1801,9 @@ class UTCKKSRNS : public ::testing::TestWithParam<TEST_CASE_UTCKKSRNS> {
18011801
// x + x^2 - x^3
18021802
// low-degree function to check linear implementation
18031803
std::vector<double> coefficients5{0, 1, 1, -1};
1804+
// 1 + 2x
1805+
// linear function to check linear implementation
1806+
std::vector<double> coefficients6{1.0, 2.0};
18041807

18051808
Plaintext plaintext1 = cc->MakeCKKSPackedPlaintext(input);
18061809

@@ -1819,6 +1822,9 @@ class UTCKKSRNS : public ::testing::TestWithParam<TEST_CASE_UTCKKSRNS> {
18191822
std::vector<std::complex<double>> output5{0.625, 0.847, 0.9809999999, 0.995125, 0.990543};
18201823
Plaintext plaintextResult5 = cc->MakeCKKSPackedPlaintext(output5);
18211824

1825+
std::vector<std::complex<double>> output6{2.0, 2.4, 2.8, 2.9, 2.86};
1826+
Plaintext plaintextResult6 = cc->MakeCKKSPackedPlaintext(output6);
1827+
18221828
// Generate encryption keys
18231829
KeyPair<Element> kp = cc->KeyGen();
18241830
// Generate multiplication keys
@@ -1878,6 +1884,16 @@ class UTCKKSRNS : public ::testing::TestWithParam<TEST_CASE_UTCKKSRNS> {
18781884
<< " - we get: " << results5->GetCKKSPackedValue();
18791885
checkEquality(plaintextResult5->GetCKKSPackedValue(), results5->GetCKKSPackedValue(), epsHigh,
18801886
failmsg + " EvalPoly for low-degree polynomial failed: " + buffer5.str());
1887+
1888+
Ciphertext<Element> cResult6 = cc->EvalPolyLinear(ciphertext1, coefficients6);
1889+
Plaintext results6;
1890+
cc->Decrypt(kp.secretKey, cResult6, &results6);
1891+
results6->SetLength(encodedLength);
1892+
std::stringstream buffer6;
1893+
buffer6 << "should be: " << plaintextResult6->GetCKKSPackedValue()
1894+
<< " - we get: " << results6->GetCKKSPackedValue();
1895+
checkEquality(plaintextResult6->GetCKKSPackedValue(), results6->GetCKKSPackedValue(), epsHigh,
1896+
failmsg + " EvalPoly for linear polynomial failed: " + buffer6.str());
18811897
}
18821898
catch (std::exception& e) {
18831899
std::cerr << "Exception thrown from " << __func__ << "(): " << e.what() << std::endl;

0 commit comments

Comments
 (0)