Skip to content

Commit 085310c

Browse files
Moved the logic to deduplicate the indices in key generation operations for automorphism keys to a lower level
1 parent fcb4c26 commit 085310c

File tree

2 files changed

+31
-30
lines changed

2 files changed

+31
-30
lines changed

src/pke/include/cryptocontext.h

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -230,21 +230,6 @@ class CryptoContextImpl : public Serializable {
230230
value2);
231231
}
232232

233-
/**
234-
* @brief Gets indices that do not have automorphism keys for the given secret key tag in the key map
235-
*
236-
* @param keyTag secret key tag
237-
* @param indexList array of specific indices to check the key map against
238-
* @return indices that do not have automorphism keys associated with
239-
*/
240-
static std::set<uint32_t> GetEvalAutomorphismNoKeyIndices(const std::string& keyTag,
241-
const std::set<uint32_t>& indices) {
242-
std::set<uint32_t> existingIndices{CryptoContextImpl<Element>::GetExistingEvalAutomorphismKeyIndices(keyTag)};
243-
// if no index found for the given keyTag, then the entire set "indices" is returned
244-
return (existingIndices.empty()) ? indices :
245-
CryptoContextImpl<Element>::GetUniqueValues(existingIndices, indices);
246-
}
247-
248233
/**
249234
* @brief Gets automorphism keys for a specific secret key tag and an array of specific indices
250235
*
@@ -2250,13 +2235,7 @@ class CryptoContextImpl : public Serializable {
22502235
if (!indexList.size())
22512236
OPENFHE_THROW("Input index vector is empty");
22522237

2253-
// Do not generate duplicate keys that have been already generated and added to the static storage (map)
2254-
std::set<uint32_t> allIndices(indexList.begin(), indexList.end());
2255-
std::set<uint32_t> indicesToGenerate{
2256-
CryptoContextImpl<Element>::GetEvalAutomorphismNoKeyIndices(privateKey->GetKeyTag(), allIndices)};
2257-
2258-
std::vector<uint32_t> newIndices(indicesToGenerate.begin(), indicesToGenerate.end());
2259-
auto evalKeys = GetScheme()->EvalAutomorphismKeyGen(privateKey, newIndices);
2238+
auto evalKeys = GetScheme()->EvalAutomorphismKeyGen(privateKey, indexList);
22602239
CryptoContextImpl<Element>::InsertEvalAutomorphismKey(evalKeys, privateKey->GetKeyTag());
22612240

22622241
return evalKeys;
@@ -4015,6 +3994,21 @@ class CryptoContextImpl : public Serializable {
40153994
GetScheme()->SetSwkFC(FHEWtoCKKSswk);
40163995
}
40173996

3997+
/**
3998+
* @brief Gets indices that do not have automorphism keys for the given secret key tag in the key map
3999+
*
4000+
* @param keyTag secret key tag
4001+
* @param indexList array of specific indices to check the key map against
4002+
* @return indices that do not have automorphism keys associated with
4003+
*/
4004+
static std::set<uint32_t> GetEvalAutomorphismNoKeyIndices(const std::string& keyTag,
4005+
const std::set<uint32_t>& indices) {
4006+
std::set<uint32_t> existingIndices{CryptoContextImpl<Element>::GetExistingEvalAutomorphismKeyIndices(keyTag)};
4007+
// if no index found for the given keyTag, then the entire set "indices" is returned
4008+
return (existingIndices.empty()) ? indices :
4009+
CryptoContextImpl<Element>::GetUniqueValues(existingIndices, indices);
4010+
}
4011+
40184012
/**
40194013
* @brief Returns automorphism indices for all existing evaluation keys.
40204014
* @param keyTag Secret key tag.

src/pke/lib/schemebase/base-leveledshe.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,16 @@ void LeveledSHEBase<Element>::RelinearizeInPlace(Ciphertext<Element>& ciphertext
347347
template <class Element>
348348
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> LeveledSHEBase<Element>::EvalAutomorphismKeyGen(
349349
const PrivateKey<Element> privateKey, const std::vector<uint32_t>& indexList) const {
350+
351+
// Do not generate duplicate keys that have been already generated and added to the static storage (map)
352+
std::set<uint32_t> allIndices(indexList.begin(), indexList.end());
353+
std::set<uint32_t> indicesToGenerate{
354+
CryptoContextImpl<Element>::GetEvalAutomorphismNoKeyIndices(privateKey->GetKeyTag(), allIndices)};
355+
std::vector<uint32_t> newIndices(indicesToGenerate.begin(), indicesToGenerate.end());
356+
350357
// we already have checks on higher level?
351-
// auto it = std::find(indexList.begin(), indexList.end(), 2 * n - 1);
352-
// if (it != indexList.end())
358+
// auto it = std::find(newIndices.begin(), newIndices.end(), 2 * n - 1);
359+
// if (it != newIndices.end())
353360
// OPENFHE_THROW("conjugation is disabled");
354361

355362
const auto cc = privateKey->GetCryptoContext();
@@ -359,26 +366,26 @@ std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> LeveledSHEBase<Element>::E
359366
const uint32_t M = s.GetCyclotomicOrder();
360367

361368
// we already have checks on higher level?
362-
// if (indexList.size() > N - 1)
369+
// if (newIndices.size() > N - 1)
363370
// OPENFHE_THROW("size exceeds the ring dimension");
364371

365-
// create and initialize the key map (key is a value from indexList, EvalKey is nullptr). in this case
372+
// create and initialize the key map (key is a value from newIndices, EvalKey is nullptr). in this case
366373
// we should be able to assign values to the map without using "omp critical" as all evalKeys' elements would
367374
// have already been created
368375
auto evalKeys = std::make_shared<std::map<uint32_t, EvalKey<Element>>>();
369-
for (auto indx : indexList)
376+
for (auto indx : newIndices)
370377
(*evalKeys)[indx];
371378

372-
const uint32_t sz = indexList.size();
379+
const uint32_t sz = newIndices.size();
373380
#pragma omp parallel for
374381
for (uint32_t i = 0; i < sz; ++i) {
375-
auto index = NativeInteger(indexList[i]).ModInverse(M).ConvertToInt<uint32_t>();
382+
auto index = NativeInteger(newIndices[i]).ModInverse(M).ConvertToInt<uint32_t>();
376383
std::vector<uint32_t> vec(N);
377384
PrecomputeAutoMap(N, index, &vec);
378385

379386
auto privateKeyPermuted = std::make_shared<PrivateKeyImpl<Element>>(cc);
380387
privateKeyPermuted->SetPrivateElement(s.AutomorphismTransform(index, vec));
381-
(*evalKeys)[indexList[i]] = cc->GetScheme()->KeySwitchGen(privateKey, privateKeyPermuted);
388+
(*evalKeys)[newIndices[i]] = cc->GetScheme()->KeySwitchGen(privateKey, privateKeyPermuted);
382389
}
383390

384391
return evalKeys;

0 commit comments

Comments
 (0)