Skip to content

Commit b9bdf4d

Browse files
committed
Fix evaluation of linear functions with EvalPoly
The `internalEvalPolyLinearWithPrecomp` contained an off-by-one error that caused it to throw an exception when the polynomial to evaluate was a linear function. This patch fixes that mistake, and adds a test that evaluates a linear polynomial on a CKKS ciphertext.
1 parent aa39198 commit b9bdf4d

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

src/pke/lib/scheme/ckksrns/ckksrns-advancedshe.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,9 +331,10 @@ std::shared_ptr<seriesPowers<DCRTPoly>> AdvancedSHECKKSRNS::EvalPowers(
331331
template <typename VectorDataType>
332332
static inline Ciphertext<DCRTPoly> internalEvalPolyLinearWithPrecomp(std::vector<Ciphertext<DCRTPoly>>& powers,
333333
const std::vector<VectorDataType>& coefficients) {
334+
if (coefficients.size() < 2)
335+
OPENFHE_THROW("EvalPolyLinear: The coefficients vector should contain at least 2 elements");
336+
334337
uint32_t k = coefficients.size() - 1;
335-
if (k <= 1)
336-
OPENFHE_THROW("The coefficients vector should contain at least 2 elements");
337338

338339
if (!IsNotEqualZero(coefficients[k]))
339340
OPENFHE_THROW("EvalPolyLinear: The highest-order coefficient cannot be set to 0.");

src/pke/unittest/utckksrns/UnitTestCKKSrns.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,6 +1801,9 @@ class UTCKKSRNS : public ::testing::TestWithParam<TEST_CASE_UTCKKSRNS> {
18011801
// x + x^2 - x^3
18021802
// low-degree function to check linear implementation
18031803
std::vector<double> coefficients5{0, 1, 1, -1};
1804+
// 1 + 2x
1805+
// linear function to check linear implementation
1806+
std::vector<double> coefficients6{1.0, 2.0};
18041807

18051808
Plaintext plaintext1 = cc->MakeCKKSPackedPlaintext(input);
18061809

@@ -1819,6 +1822,9 @@ class UTCKKSRNS : public ::testing::TestWithParam<TEST_CASE_UTCKKSRNS> {
18191822
std::vector<std::complex<double>> output5{0.625, 0.847, 0.9809999999, 0.995125, 0.990543};
18201823
Plaintext plaintextResult5 = cc->MakeCKKSPackedPlaintext(output5);
18211824

1825+
std::vector<std::complex<double>> output6{2.0, 2.4, 2.8, 2.9, 2.86};
1826+
Plaintext plaintextResult6 = cc->MakeCKKSPackedPlaintext(output6);
1827+
18221828
// Generate encryption keys
18231829
KeyPair<Element> kp = cc->KeyGen();
18241830
// Generate multiplication keys
@@ -1878,6 +1884,16 @@ class UTCKKSRNS : public ::testing::TestWithParam<TEST_CASE_UTCKKSRNS> {
18781884
<< " - we get: " << results5->GetCKKSPackedValue();
18791885
checkEquality(plaintextResult5->GetCKKSPackedValue(), results5->GetCKKSPackedValue(), epsHigh,
18801886
failmsg + " EvalPoly for low-degree polynomial failed: " + buffer5.str());
1887+
1888+
Ciphertext<Element> cResult6 = cc->EvalPolyLinear(ciphertext1, coefficients6);
1889+
Plaintext results6;
1890+
cc->Decrypt(kp.secretKey, cResult6, &results6);
1891+
results6->SetLength(encodedLength);
1892+
std::stringstream buffer6;
1893+
buffer6 << "should be: " << plaintextResult6->GetCKKSPackedValue()
1894+
<< " - we get: " << results6->GetCKKSPackedValue();
1895+
checkEquality(plaintextResult6->GetCKKSPackedValue(), results6->GetCKKSPackedValue(), epsHigh,
1896+
failmsg + " EvalPoly for linear polynomial failed: " + buffer6.str());
18811897
}
18821898
catch (std::exception& e) {
18831899
std::cerr << "Exception thrown from " << __func__ << "(): " << e.what() << std::endl;

0 commit comments

Comments
 (0)