CKKS-RNS Automorphism before key-switching leads to excessive noise and decryption failure

Hello, I am studying the rotation operation EvalRotate in OpenFHE. To understand its steps in more detail, I reproduced them one by one in my main function.

EvalRotate performs the automorphism after key switching (step 8 in the code below).

I tried performing the automorphism before the key-switching operation instead (step 4 in the code below), and got this error:
terminate called after throwing an instance of 'lbcrypto::math_error' what(): E:/openfhe-development/src/pke/lib/encoding/ckkspackedencoding.cpp:535 The decryption failed because the approximation error is too high. Check the parameters.

An automorphism only permutes the values in the slots, so why would applying it before key switching cause so much noise? Does performing the automorphism before versus after key switching result in different noise levels?

Here is the code:


#include "openfhe.h"

using namespace lbcrypto;
// Forward declaration; the definition follows main().
std::shared_ptr<std::vector<DCRTPoly>> EvalFastKeySwitchCoreExt(
const std::shared_ptr<std::vector<DCRTPoly>> digits,
const EvalKey<DCRTPoly> evalKey,
const std::shared_ptr<lbcrypto::M4DCRTParams> paramsQl);

int main() {
uint32_t batchSize = 8;

CCParams<CryptoContextCKKSRNS> parameters;

parameters.SetSecurityLevel(HEStd_NotSet);
parameters.SetBatchSize(batchSize);
parameters.SetRingDim(batchSize*2);//(1<<10);
parameters.SetMultiplicativeDepth(5);
parameters.SetFirstModSize(40);
parameters.SetScalingModSize(35);

uint32_t dnum = 2;
parameters.SetKeySwitchTechnique(HYBRID);
parameters.SetNumLargeDigits(dnum);
parameters.SetScalingTechnique(FLEXIBLEAUTO); 

CryptoContext<DCRTPoly> cc = GenCryptoContext(parameters);

cc->Enable(PKE);
cc->Enable(KEYSWITCH);
cc->Enable(LEVELEDSHE);
std::cout << "CKKS scheme is using ring dimension " << cc->GetRingDimension() << std::endl << std::endl;
std::cout << "CKKS scheme is using modulus " << cc->GetModulus() << std::endl << std::endl;

const auto cryptoParamsCKKS = std::dynamic_pointer_cast<CryptoParametersCKKSRNS>(cc->GetCryptoParameters());

auto paramsQ = cc->GetElementParams()->GetParams();
std::cout << "\nModuli in Q:" << std::endl;
for (uint32_t i = 0; i < paramsQ.size(); i++) {
  // q0 is a bit larger than the other moduli because SetFirstModSize(40)
  // is larger than SetScalingModSize(35) above.
  std::cout << "q" << i << ": " << paramsQ[i]->GetModulus() << std::endl;
}
auto paramsQP = cryptoParamsCKKS->GetParamsQP();
std::cout << "Moduli in P: " << std::endl;
BigInteger P = BigInteger(1);
for (uint32_t i = 0; i < paramsQP->GetParams().size(); i++) {
  if (i >= paramsQ.size()) {
    P = P * BigInteger(paramsQP->GetParams()[i]->GetModulus());
    std::cout << "p" << i - paramsQ.size() << ": "
              << paramsQP->GetParams()[i]->GetModulus() << std::endl;
  }
}
auto QBitLength = cc->GetModulus().GetLengthForBase(2);
auto PBitLength = P.GetLengthForBase(2);
std::cout << "\nQ = " << cc->GetModulus() << " (bit length: " << QBitLength
          << ")" << std::endl;
std::cout << "P = " << P << " (bit length: " << PBitLength << ")"
          << std::endl;
std::cout << "Total bit-length of ciphertext modulus: "
          << QBitLength + PBitLength << std::endl;
std::cout << "Given this ciphertext modulus, a ring dimension of "
          << cc->GetRingDimension() << " gives us 128-bit security."
          << std::endl;

auto keys = cc->KeyGen();

cc->EvalMultKeyGen(keys.secretKey);

cc->EvalRotateKeyGen(keys.secretKey, {1, -2});

// Encoding and encryption of inputs

// Inputs
std::vector<double> x1 = {0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0, 5.0};
std::vector<double> x2 = {5.0, 4.0, 3.0, 2.0, 1.0, 0.75, 0.5, 0.25};

// Encoding as plaintexts
Plaintext ptxt1 = cc->MakeCKKSPackedPlaintext(x1);
Plaintext ptxt2 = cc->MakeCKKSPackedPlaintext(x2);

std::cout << "Input x1: " << ptxt1 << std::endl;
std::cout << "Input x2: " << ptxt2 << std::endl;

// Encrypt the encoded vectors
auto c1 = cc->Encrypt(keys.publicKey, ptxt1);
auto c2 = cc->Encrypt(keys.publicKey, ptxt2);

// Homomorphic rotations
auto cRot1 = cc->EvalRotate(c1, 1);

auto ctest = c1; // ciphertext
int32_t index = -2;
// 1. EvalRotate(ConstCiphertext<Element> ciphertext, int32_t index)
    auto evalKeyMap = CryptoContextImpl<DCRTPoly>::GetEvalAutomorphismKeyMap(ctest->GetKeyTag());
// 2. return GetScheme()->EvalAtIndex(ciphertext ctest, int32_t index, const std::map<usint, EvalKey<Element>>& evalKeyMap);
    usint M = ctest->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder();
    uint32_t autoIndex = FindAutomorphismIndex2nComplex(index, M);
// 3. EvalAutomorphism(ciphertext, autoIndex, evalKeyMap);
    auto evalKeyIterator = evalKeyMap.find(autoIndex);
    const std::vector<DCRTPoly>& cv = ctest->GetElements();
    usint N = cv[0].GetRingDimension();
    std::vector<usint> vec(N);
    PrecomputeAutoMap(N, autoIndex, &vec);
    auto algo = ctest->GetCryptoContext()->GetScheme();  // cc->GetScheme()
    Ciphertext<DCRTPoly> c1_clone = ctest->Clone();
    auto evalKey = evalKeyIterator->second;
    std::vector<DCRTPoly>& cv_clone = c1_clone->GetElements();
// 4. AutomorphismTransform before keyswitching
    cv_clone[0] = cv_clone[0].AutomorphismTransform(autoIndex, vec);
    cv_clone[1] = cv_clone[1].AutomorphismTransform(autoIndex, vec);
            /* std::shared_ptr<std::vector<DCRTPoly>> ba = (cv_clone.size() == 2) ? KeySwitchCore(cv_clone[1], evalKeyIterator->second) : KeySwitchCore(cv_clone[2], evalKeyIterator->second);
               std::shared_ptr<std::vector<DCRTPoly>> KeySwitchHYBRID::KeySwitchCore(const DCRTPoly& a, const EvalKey<DCRTPoly> evalKey)
            */
// 5. EvalKeySwitchPrecomputeCore(cv_clone[1], evalKeyIterator->second->GetCryptoParameters())
        const auto cryptoParams = cryptoParamsCKKS;//std::dynamic_pointer_cast<CryptoParametersRNS>(evalKeyIterator->second->GetCryptoParameters());
        const auto paramsQl  = cv_clone[1].GetParams(); //const std::shared_ptr<ParmType> paramsQl  = cv_clone[1].GetParams();
        const auto paramsP   = cryptoParams->GetParamsP(); //const std::shared_ptr<ParmType> paramsP   = cryptoParams->GetParamsP();
        const auto paramsQlP = cv_clone[1].GetExtendedCRTBasis(paramsP); //const std::shared_ptr<ParmType> paramsQlP = cv_clone[1].GetExtendedCRTBasis(paramsP);
        size_t sizeQl  = paramsQl->GetParams().size();
        size_t sizeP   = paramsP->GetParams().size();
        size_t sizeQlP = sizeQl + sizeP;
        uint32_t alpha = cryptoParams->GetNumPerPartQ();   
        uint32_t numPartQl = ceil((static_cast<double>(sizeQl)) / alpha); 
        if (numPartQl > cryptoParams->GetNumberOfQPartitions())
            numPartQl = cryptoParams->GetNumberOfQPartitions();
        std::vector<DCRTPoly> partsCt(numPartQl);  
        for (uint32_t part = 0; part < numPartQl; part++) {
            if (part == numPartQl - 1) {            
                auto paramsPartQ = cryptoParams->GetParamsPartQ(part);  
        
                uint32_t sizePartQl = sizeQl - alpha * part;    
        
                std::vector<NativeInteger> moduli(sizePartQl);  
                std::vector<NativeInteger> roots(sizePartQl);
        
                for (uint32_t i = 0; i < sizePartQl; i++) {
                    moduli[i] = paramsPartQ->GetParams()[i]->GetModulus();
                    roots[i]  = paramsPartQ->GetParams()[i]->GetRootOfUnity();
                }
        
                auto params = DCRTPoly::Params(paramsPartQ->GetCyclotomicOrder(), moduli, roots);
        
                partsCt[part] = DCRTPoly(std::make_shared<lbcrypto::M4DCRTParams>(params), Format::EVALUATION, true); 
            }
            else {
                partsCt[part] = DCRTPoly(cryptoParams->GetParamsPartQ(part), Format::EVALUATION, true); 
            }
        
            usint sizePartQl   = partsCt[part].GetNumOfElements();  
            usint startPartIdx = alpha * part;  // split 
            for (uint32_t i = 0, idx = startPartIdx; i < sizePartQl; i++, idx++) {
                partsCt[part].SetElementAtIndex(i, cv_clone[1].GetElementAtIndex(idx));
            }
            std::cout << "Digit " << part << " has " << sizePartQl << " sub-moduli." << std::endl;
        }
        std::vector<DCRTPoly> partsCtCompl(numPartQl);  
        std::vector<DCRTPoly> partsCtExt(numPartQl);    
        for (uint32_t part = 0; part < numPartQl; part++) {
            auto partCtClone = partsCt[part].Clone();   
            partCtClone.SetFormat(Format::COEFFICIENT);

            uint32_t sizePartQl = partsCt[part].GetNumOfElements();
            partsCtCompl[part]  = partCtClone.ApproxSwitchCRTBasis( 
            cryptoParams->GetParamsPartQ(part),
            cryptoParams->GetParamsComplPartQ(sizeQl - 1, part),    
            cryptoParams->GetPartQlHatInvModq(part, sizePartQl - 1), 
            cryptoParams->GetPartQlHatInvModqPrecon(part, sizePartQl - 1),  
            cryptoParams->GetPartQlHatModp(sizeQl - 1, part),   
            cryptoParams->GetmodComplPartqBarrettMu(sizeQl - 1, part) 
            );
            partsCtCompl[part].SetFormat(Format::EVALUATION);
            partsCtExt[part] = DCRTPoly(paramsQlP, Format::EVALUATION, true);
            usint startPartIdx = alpha * part;
            usint endPartIdx   = startPartIdx + sizePartQl;
            for (usint i = 0; i < startPartIdx; i++) {
                partsCtExt[part].SetElementAtIndex(i, partsCtCompl[part].GetElementAtIndex(i));
            }
            for (usint i = startPartIdx, idx = 0; i < endPartIdx; i++, idx++) {
                partsCtExt[part].SetElementAtIndex(i, partsCt[part].GetElementAtIndex(idx));
            }
            for (usint i = endPartIdx; i < sizeQlP; ++i) {
                partsCtExt[part].SetElementAtIndex(i, partsCtCompl[part].GetElementAtIndex(i - sizePartQl));
            }

            auto otherdigit = cryptoParams->GetParamsComplPartQ(sizeQl - 1, part);
            auto qhatinv = cryptoParams->GetPartQlHatInvModq(part, sizePartQl - 1);
            auto qhatmodp = cryptoParams->GetPartQlHatModp(sizeQl - 1, part);
            std::cout << "digit " << part << " 's qhatinv"<< std::endl;
            for (usint i=0; i<sizePartQl; i++){
                std::cout << "-------------- Start of qhatinv --------------" << std::endl;
                std::cout << "qhatinv(" << i << ") = " << qhatinv[i] << std::endl;
                std::cout << "qhat(" << i << ") mod other digits' or P's sub moduli : " << std::endl;
                for (usint j=0; j < otherdigit->GetParams().size(); j++){
                    std::cout << qhatmodp[i][j] << std::endl;
                }
                std::cout << "-------------- End of qhatmodp --------------" << std::endl;
            }
        }
// 6. EvalFastKeySwitchCore(std::make_shared<std::vector<DCRTPoly>>(std::move(partsCtExt)), evalKeyIterator->second, cv_clone[1].GetParams());
    // 6-1. EvalFastKeySwitchCoreExt(digits, evalKey, paramsQl); // Responsible for Inner Product
    std::shared_ptr<std::vector<DCRTPoly>> cTilda = EvalFastKeySwitchCoreExt(std::make_shared<std::vector<DCRTPoly>>(std::move(partsCtExt)), evalKey, cv_clone[1].GetParams());
    // 6-2. ModDown
    DCRTPoly ct0 = (*cTilda)[0].ApproxModDown(  paramsQl, 
                                                cryptoParams->GetParamsP(), 
                                                cryptoParams->GetPInvModq(),
                                                cryptoParams->GetPInvModqPrecon(), 
                                                cryptoParams->GetPHatInvModp(),
                                                cryptoParams->GetPHatInvModpPrecon(), 
                                                cryptoParams->GetPHatModq(),
                                                cryptoParams->GetModqBarrettMu(), 
                                                cryptoParams->GettInvModp(),
                                                cryptoParams->GettInvModpPrecon(), 
                                                0, 
                                                cryptoParams->GettModqPrecon()
                                            );
                    
    DCRTPoly ct1 = (*cTilda)[1].ApproxModDown(  paramsQl, 
                                                cryptoParams->GetParamsP(), 
                                                cryptoParams->GetPInvModq(),
                                                cryptoParams->GetPInvModqPrecon(),      // unused
                                                cryptoParams->GetPHatInvModp(),
                                                cryptoParams->GetPHatInvModpPrecon(),   // unused
                                                cryptoParams->GetPHatModq(),
                                                cryptoParams->GetModqBarrettMu(), 
                                                cryptoParams->GettInvModp(),            // unused
                                                cryptoParams->GettInvModpPrecon(),      // unused
                                                0,                                      // = 0, not needed
                                                cryptoParams->GettModqPrecon()          // unused
                                            );
// 7. c0+c0', c1+c1'
    std::shared_ptr<std::vector<DCRTPoly>> ba = std::make_shared<std::vector<DCRTPoly>>(std::initializer_list<DCRTPoly>{std::move(ct0), std::move(ct1)});
    cv_clone[0].SetFormat((*ba)[0].GetFormat());
    cv_clone[0] += (*ba)[0];
    cv_clone[1].SetFormat((*ba)[1].GetFormat());
    if (cv_clone.size() > 2) {
        cv_clone[1] += (*ba)[1];
    }
    else {
        cv_clone[1] = (*ba)[1];
    }
    cv_clone.resize(2);

// 8. AutomorphismTransform after keyswitching
    // std::vector<DCRTPoly>& rcv = c1_clone->GetElements();
    //  rcv[0] = rcv[0].AutomorphismTransform(autoIndex, vec);
    //  rcv[1] = rcv[1].AutomorphismTransform(autoIndex, vec);


Plaintext result;
std::cout.precision(8);
cc->Decrypt(keys.secretKey, cRot1, &result);
result->SetLength(batchSize);
std::cout << std::endl << "In rotations, very small outputs (~10^-10 here) correspond to 0's:" << std::endl;
std::cout << "x1 rotate by 1 = " << result << std::endl;

cc->Decrypt(keys.secretKey, c1_clone, &result);
result->SetLength(batchSize);
std::cout << std::endl << "My Result is " << std::endl;
std::cout << "x1 rotate by " << index << " = " << result << std::endl;


return 0;

}

std::shared_ptr<std::vector<DCRTPoly>> EvalFastKeySwitchCoreExt(
const std::shared_ptr<std::vector<DCRTPoly>> digits,
const EvalKey<DCRTPoly> evalKey,
const std::shared_ptr<lbcrypto::M4DCRTParams> paramsQl){

const auto cryptoParams = std::dynamic_pointer_cast<CryptoParametersRNS>(evalKey->GetCryptoParameters());   
// original params. No matter which level. evalKey always carry the original params.

const std::vector<DCRTPoly>& bv = evalKey->GetBVector();    // keyB
const std::vector<DCRTPoly>& av = evalKey->GetAVector();    // keyA

const std::shared_ptr<lbcrypto::M4DCRTParams> paramsP   = cryptoParams->GetParamsP();
// gets from original params. Since P will never change.
const std::shared_ptr<lbcrypto::M4DCRTParams> paramsQlP = (*digits)[0].GetParams();   
// get current params Ql+P. *digits[0] is digit0 Poly (a DCRTPoly). inside of it, each element is a respoly (q_0...q_l,p_0,...p_{alpha-1}). 

size_t sizeQl  = paramsQl->GetParams().size();
// get current params Ql'size(). Get (l+1)

size_t sizeQlP = paramsQlP->GetParams().size();
// Get (l+1)+alpha

size_t sizeQ   = cryptoParams->GetElementParams()->GetParams().size();
// Get original size QL = L+1? Used to locate P's index.

DCRTPoly cTilda0(paramsQlP, Format::EVALUATION, true);  // Inner Product = evkA*d2
DCRTPoly cTilda1(paramsQlP, Format::EVALUATION, true);  // Inner Product = evkB*d2

for (uint32_t j = 0; j < digits->size(); j++) { // j = 0 -> (beta-1). representing digit0 to digit(beta-1).
    const DCRTPoly& cj = (*digits)[j];  // cj = (*digits)[j] is digitj Poly. d_{2,j}^{(q_0~q_l, p_0~p_{alpha-1})}
    const DCRTPoly& bj = bv[j];         // bj = evkb_{j}^{(q_0~q_L, p_0~p_{alpha-1})}
    const DCRTPoly& aj = av[j];         // aj = evka_{j}^{(q_0~q_L, p_0~p_{alpha-1})}

    for (usint i = 0; i < sizeQl; i++) {    
        const auto& cji = cj.GetElementAtIndex(i);  // cji = d_{2,j}^{(q_i)}
        const auto& aji = aj.GetElementAtIndex(i);  // aji = evka_{j}^{(q_i)}
        const auto& bji = bj.GetElementAtIndex(i);  // bji = evkb_{j}^{(q_i)}

        cTilda0.SetElementAtIndex(i, cTilda0.GetElementAtIndex(i) + cji * bji); // ct0^{(q_i)} = sum_{j=0}^{beta-1} ( d_{2,j}^{(q_i)} * evkb_{j}^{(q_i)} ) mod q_i
        cTilda1.SetElementAtIndex(i, cTilda1.GetElementAtIndex(i) + cji * aji); // ct1^{(q_i)} = sum_{j=0}^{beta-1} ( d_{2,j}^{(q_i)} * evka_{j}^{(q_i)} ) mod q_i
    }   // for specific index j, now we get d_{2,j} * evka_j/evkb_j contribution to ct0 and ct1. @ q0~ql
    for (usint i = sizeQl, idx = sizeQ; i < sizeQlP; i++, idx++) {  // @ P's sub-moduli.
        const auto& cji = cj.GetElementAtIndex(i);  // cji = d_{2,j}^{ ( p_{i-sizeQl} ) }
        const auto& aji = aj.GetElementAtIndex(idx);// aji = evka_j^{ ( p_{idx - (L+1)} ) }
        const auto& bji = bj.GetElementAtIndex(idx);// bji = evkb_j^{ ( p_{idx - (L+1)} ) }

        cTilda0.SetElementAtIndex(i, cTilda0.GetElementAtIndex(i) + cji * bji);
        cTilda1.SetElementAtIndex(i, cTilda1.GetElementAtIndex(i) + cji * aji);
    }   // for specific index j, now we get d_{2,j} * evka_j/evkb_j contribution to ct0 and ct1. @ p0~p_{alpha-1}
}   // After all j, we get the Inner Product Results : ct0 and ct1 @ q_0~q_l and p_0~p_{alpha-1}.

return std::make_shared<std::vector<DCRTPoly>>(
    std::initializer_list<DCRTPoly>{std::move(cTilda0), std::move(cTilda1)});

}

Automorphism does not affect the norm of the noise (it stays the same), but it changes the values (achieving the desired permutation effect).
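
To make the norm claim concrete, here is a minimal standalone sketch (plain C++, not OpenFHE code) of the map X -> X^k on a coefficient vector in Z[X]/(X^N + 1). The helper names ApplyAutomorphism and InfNorm are made up for this illustration. For odd k the map only moves coefficients to new positions and flips some signs, so the infinity norm is unchanged.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical helper: apply sigma_k : X -> X^k to a polynomial in
// Z[X]/(X^N + 1), given as its N coefficients. k must be odd.
std::vector<int64_t> ApplyAutomorphism(const std::vector<int64_t>& a, uint32_t k) {
    const uint32_t N = static_cast<uint32_t>(a.size());
    std::vector<int64_t> out(N, 0);
    for (uint32_t i = 0; i < N; ++i) {
        // X^(i*k) reduced using X^(2N) = 1.
        const uint32_t e = static_cast<uint32_t>((static_cast<uint64_t>(i) * k) % (2ULL * N));
        if (e < N) {
            out[e] = a[i];
        } else {
            out[e - N] = -a[i];  // X^N = -1 flips the sign
        }
    }
    return out;
}

// Hypothetical helper: largest absolute coefficient.
int64_t InfNorm(const std::vector<int64_t>& a) {
    int64_t m = 0;
    for (int64_t c : a) {
        m = std::max<int64_t>(m, c < 0 ? -c : c);
    }
    return m;
}
```

For a = {1, 2, 3, 4, 5, 6, 7, 8} and k = 5 (so N = 8, 2N = 16), ApplyAutomorphism returns {1, -6, -3, 8, 5, 2, -7, -4}, and InfNorm is 8 both before and after. Applying k = 13, the inverse of 5 modulo 16, restores the original vector.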

In OpenFHE, we use the so-called second "hoisting" technique (from "Efficient Bootstrapping for Approximate Homomorphic Encryption with Non-Sparse Keys"): the automorphism is done at the very end. This requires using an inverse automorphism during key generation. Hence, with EvalAutomorphism, always remember that one automorphism is applied to the ciphertext and another is applied during automorphism key generation. The math has to work out so that after both automorphisms (one applied directly to the ciphertext and the other applied implicitly to the secret key) the correct result is produced. It is not sufficient to modify only the ciphertext automorphism.

The error you are getting means the result is incorrect (you are not permuting back to the secret key after the complete rotation workflow).
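
A short derivation may help show why the order matters. It is a sketch under one assumption: that the rotation key for index k switches from s to \sigma_k^{-1}(s), which is how the inverse automorphism in key generation described above would work. The exact key layout has not been checked against the OpenFHE source here. The symbols \sigma, \mathrm{KS}, m and e are notation introduced only for this sketch.

```latex
% Notation: \sigma = \sigma_k is the automorphism X \mapsto X^k,
% (c_0, c_1) is a ciphertext with c_0 + c_1 s = m + e.
\begin{align*}
\text{Key switch, then automorphism (library order):}\quad
  & (c_0', c_1') = \mathrm{KS}_{s \to \sigma^{-1}(s)}(c_0, c_1),
    \qquad c_0' + c_1'\,\sigma^{-1}(s) \approx m + e, \\
  & \sigma(c_0') + \sigma(c_1')\,s \approx \sigma(m + e). \\[4pt]
\text{Automorphism, then the same key (modified order):}\quad
  & \sigma(c_0) + \sigma(c_1)\,\sigma(s) = \sigma(m + e)
    \quad\text{(valid under } \sigma(s)\text{, not under } s\text{)}, \\
  & (c_0'', c_1'') = \mathrm{KS}_{s \to \sigma^{-1}(s)}(\sigma(c_0), \sigma(c_1)),
    \qquad c_0'' + c_1''\,\sigma^{-1}(s) \approx \sigma(c_0) + \sigma(c_1)\,s
    \neq \sigma(m + e) \text{ in general.}
\end{align*}
```

In the first case the final ciphertext decrypts under s to the rotated message. In the second case the key switch treats \sigma(c_1) as if it multiplied s, while it actually multiplies \sigma(s), so the output is not an encryption of the rotated message under any key the decryptor holds. Under the same assumption, putting the automorphism first would need a key that switches from \sigma(s) to s instead.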