CKKS Bootstrapping fails after deserialization

Hi,

I am trying to perform bootstrapping in CKKS, but am having trouble getting it to work in combination with serialization / deserialization.
My understanding from the answer in https://openfhe.discourse.group/t/serialize-the-bootstrap-key/433/3 is that serializing the keys for bootstrapping should be possible by serializing the EvalMult and EvalAutomorphism keys, as shown in the linked example “simple-real-numbers-serial”.

However, using this example as a base, I added two calls to EvalBootstrap, one before serialization and one after. The first one runs fine, but the second one fails with the following error message:
terminate called after throwing an instance of 'lbcrypto::OpenFHEException'
  what(): /usr/local/include/openfhe/pke/schemebase/base-scheme.h:l.355:KeySwitchDown(): Input ciphertext is nullptr.

Could this be a bug in the library, or am I missing anything?

I am pasting the exact code I ran below. I kept most of the initial example and only added the bootstrapping setup and EvalBootstrap calls. Apologies that this is not exactly a minimal example, but I wanted to stay as close to the original serialization example code as possible.

Thank you very much for your help!

/*
  Real number serialization in a simple context. The goal of this is to show a simple setup for real number
  serialization before progressing into the next logical step - serialization and communication across
  2 separate entities
 */

#include <iomanip>
#include <tuple>
#include <unistd.h>

#include "openfhe.h"

// header files needed for serialization
#include "ciphertext-ser.h"
#include "cryptocontext-ser.h"
#include "key/key-ser.h"
#include "scheme/ckksrns/ckksrns-ser.h"

using namespace lbcrypto;

/////////////////////////////////////////////////////////////////
// NOTE:
// If running locally, you may want to replace the "hardcoded" DATAFOLDER with
// the DATAFOLDER location below which gets the current working directory
/////////////////////////////////////////////////////////////////
// char buff[1024];
// std::string DATAFOLDER = std::string(getcwd(buff, 1024));

// Save-Load locations for keys
const std::string DATAFOLDER = "demoData";
std::string ccLocation       = "/cryptocontext.txt";
std::string pubKeyLocation   = "/key_pub.txt";   // Pub key
std::string multKeyLocation  = "/key_mult.txt";  // relinearization key
std::string rotKeyLocation   = "/key_rot.txt";   // automorphism / rotation key

// Save-load locations for RAW ciphertexts
std::string cipherOneLocation = "/ciphertext1.txt";
std::string cipherTwoLocation = "/ciphertext2.txt";

// Save-load locations for evaluated ciphertexts
std::string cipherMultLocation   = "/ciphertextMult.txt";
std::string cipherAddLocation    = "/ciphertextAdd.txt";
std::string cipherRotLocation    = "/ciphertextRot.txt";
std::string cipherRotNegLocation = "/ciphertextRotNegLocation.txt";
std::string clientVectorLocation = "/ciphertextVectorFromClient.txt";

/**
 * Demarcate - Visual separator between the sections of code
 * @param msg - string message that you want displayed between blocks of
 * characters
 */
void demarcate(const std::string& msg) {
    std::cout << std::setw(50) << std::setfill('*') << '\n' << std::endl;
    std::cout << msg << std::endl;
    std::cout << std::setw(50) << std::setfill('*') << '\n' << std::endl;
}

/**
 * serverSetupAndWrite
 *  - simulates a server at startup where we generate a cryptocontext and keys.
 *  - then, we generate some data (akin to loading raw data on an enclave)
 * before encrypting the data
 * @param multDepth - multiplication depth
 * @param scaleModSize - number of bits to use in the scale factor (not the
 * scale factor itself)
 * @param batchSize - batch size to use
 * @return Tuple<cryptoContext, keyPair, vectorSize>
 */
std::tuple<CryptoContext<DCRTPoly>, KeyPair<DCRTPoly>, int> serverSetupAndWrite(int multDepth, int scaleModSize,
                                                                                int batchSize) {
    CCParams<CryptoContextCKKSRNS> parameters;
    
    //SecretKeyDist secretKeyDist = UNIFORM_TERNARY;
    //parameters.SetSecretKeyDist(secretKeyDist);
    /*parameters.SetMultiplicativeDepth(multDepth);
    parameters.SetScalingModSize(scaleModSize);
    parameters.SetBatchSize(batchSize);
    parameters.SetSecurityLevel(HEStd_NotSet);
    parameters.SetRingDim(batchSize*2);

    CryptoContext<DCRTPoly> serverCC = GenCryptoContext(parameters);
    */
    

    std::cout << "Cryptocontext generated" << std::endl;

    /*std::vector<uint32_t> levelBudget = {4, 4};*/

    SecretKeyDist secretKeyDist = UNIFORM_TERNARY;
    parameters.SetSecretKeyDist(secretKeyDist);

    parameters.SetSecurityLevel(HEStd_NotSet);
    parameters.SetRingDim(1 << 12);

#if NATIVEINT == 128 && !defined(__EMSCRIPTEN__)
    ScalingTechnique rescaleTech = FIXEDAUTO;
    usint dcrtBits               = 78;
    usint firstMod               = 89;
#else
    ScalingTechnique rescaleTech = FLEXIBLEAUTO;
    usint dcrtBits               = 59;
    usint firstMod               = 60;
#endif

    parameters.SetScalingModSize(dcrtBits);
    parameters.SetScalingTechnique(rescaleTech);
    parameters.SetFirstModSize(firstMod);

    /*  A4) Multiplicative depth.
    * The goal of bootstrapping is to increase the number of available levels we have, or in other words,
    * to dynamically increase the multiplicative depth. However, the bootstrapping procedure itself
    * needs to consume a few levels to run. We compute the number of bootstrapping levels required
    * using GetBootstrapDepth, and add it to levelsAvailableAfterBootstrap to set our initial multiplicative
    * depth. We recommend using the input parameters below to get started.
    */
    std::vector<uint32_t> levelBudget = {4, 4};

    // Note that the actual number of levels available after bootstrapping, before the next bootstrapping,
    // will be levelsAvailableAfterBootstrap - 1 because an additional level
    // is used for scaling the ciphertext before the next bootstrapping (in 64-bit CKKS bootstrapping)
    uint32_t levelsAvailableAfterBootstrap = 10;
    usint depth = levelsAvailableAfterBootstrap + FHECKKSRNS::GetBootstrapDepth(levelBudget, secretKeyDist);
    parameters.SetMultiplicativeDepth(depth);


    CryptoContext<DCRTPoly> serverCC = GenCryptoContext(parameters);

    serverCC->Enable(PKE);
    serverCC->Enable(KEYSWITCH);
    serverCC->Enable(LEVELEDSHE);
    serverCC->Enable(ADVANCEDSHE);
    serverCC->Enable(FHE);
    

    serverCC->EvalBootstrapSetup(levelBudget);

    KeyPair<DCRTPoly> serverKP = serverCC->KeyGen();
    std::cout << "Keypair generated" << std::endl;

    serverCC->EvalMultKeyGen(serverKP.secretKey);
    std::cout << "Eval Mult Keys/ Relinearization keys have been generated" << std::endl;

    serverCC->EvalRotateKeyGen(serverKP.secretKey, {1, 2, -1, -2});
    std::cout << "Rotation keys generated" << std::endl;

    usint ringDim = serverCC->GetRingDimension();
    // This is the maximum number of slots that can be used for full packing.
    usint numSlots = ringDim / 2;
    std::cout << "CKKS scheme is using ring dimension " << ringDim << std::endl << std::endl;

    serverCC->EvalBootstrapKeyGen(serverKP.secretKey, numSlots);

    std::vector<std::complex<double>> vec1 = {1.0, 2.0, 3.0, 4.0};
    std::vector<std::complex<double>> vec2 = {12.5, 13.5, 14.5, 15.5};
    std::vector<std::complex<double>> vec3 = {10.5, 11.5, 12.5, 13.5};
    std::cout << "\nDisplaying first data vector: ";

    for (auto& v : vec1) {
        std::cout << v << ',';
    }

    std::cout << '\n' << std::endl;

    Plaintext serverP1 = serverCC->MakeCKKSPackedPlaintext(vec1);
    Plaintext serverP2 = serverCC->MakeCKKSPackedPlaintext(vec2);
    Plaintext serverP3 = serverCC->MakeCKKSPackedPlaintext(vec3);

    std::cout << "Plaintext version of first vector: " << serverP1 << std::endl;

    std::cout << "Plaintexts have been generated from complex-double vectors" << std::endl;

    auto serverC1 = serverCC->Encrypt(serverKP.publicKey, serverP1);
    auto serverC2 = serverCC->Encrypt(serverKP.publicKey, serverP2);
    auto serverC3 = serverCC->Encrypt(serverKP.publicKey, serverP3);

    serverC1 = serverCC->EvalBootstrap(serverC1);

    std::cout << "Ciphertexts have been generated from Plaintexts" << std::endl;

    /*
   * Part 2:
   * We serialize the following:
   *  Cryptocontext
   *  Public key
   *  relinearization (eval mult keys)
   *  rotation keys
   *  Some of the ciphertext
   *
   *  We serialize all of them to files
   */

    demarcate("Part 2: Data Serialization (server)");

    if (!Serial::SerializeToFile(DATAFOLDER + ccLocation, serverCC, SerType::BINARY)) {
        std::cerr << "Error writing serialization of the crypto context to "
                     "cryptocontext.txt"
                  << std::endl;
        std::exit(1);
    }

    std::cout << "Cryptocontext serialized" << std::endl;

    if (!Serial::SerializeToFile(DATAFOLDER + pubKeyLocation, serverKP.publicKey, SerType::BINARY)) {
        std::cerr << "Exception writing public key to pubkey.txt" << std::endl;
        std::exit(1);
    }
    std::cout << "Public key serialized" << std::endl;

    std::ofstream multKeyFile(DATAFOLDER + multKeyLocation, std::ios::out | std::ios::binary);
    if (multKeyFile.is_open()) {
        if (!serverCC->SerializeEvalMultKey(multKeyFile, SerType::BINARY)) {
            std::cerr << "Error writing eval mult keys" << std::endl;
            std::exit(1);
        }
        std::cout << "EvalMult/ relinearization keys have been serialized" << std::endl;
        multKeyFile.close();
    }
    else {
        std::cerr << "Error serializing EvalMult keys" << std::endl;
        std::exit(1);
    }

    std::ofstream rotationKeyFile(DATAFOLDER + rotKeyLocation, std::ios::out | std::ios::binary);
    if (rotationKeyFile.is_open()) {
        if (!serverCC->SerializeEvalAutomorphismKey(rotationKeyFile, SerType::BINARY)) {
            std::cerr << "Error writing rotation keys" << std::endl;
            std::exit(1);
        }
        std::cout << "Rotation keys have been serialized" << std::endl;
    }
    else {
        std::cerr << "Error serializing Rotation keys" << std::endl;
        std::exit(1);
    }

    if (!Serial::SerializeToFile(DATAFOLDER + cipherOneLocation, serverC1, SerType::BINARY)) {
        std::cerr << " Error writing ciphertext 1" << std::endl;
    }

    if (!Serial::SerializeToFile(DATAFOLDER + cipherTwoLocation, serverC2, SerType::BINARY)) {
        std::cerr << " Error writing ciphertext 2" << std::endl;
    }

    return std::make_tuple(serverCC, serverKP, vec1.size());
}

/**
 * clientProcess
 *  - deserialize data from a file which simulates receiving data from a server
 * after making a request
 *  - we then process the data by doing operations (multiplication, addition,
 * rotation, etc)
 *  - !! We also create an object and encrypt it in this function before sending
 * it off to the server to be decrypted
 */
void clientProcess() {
    CryptoContext<DCRTPoly> clientCC;
    clientCC->ClearEvalMultKeys();
    clientCC->ClearEvalAutomorphismKeys();
    lbcrypto::CryptoContextFactory<lbcrypto::DCRTPoly>::ReleaseAllContexts();
    if (!Serial::DeserializeFromFile(DATAFOLDER + ccLocation, clientCC, SerType::BINARY)) {
        std::cerr << "I cannot read serialized data from: " << DATAFOLDER << "/cryptocontext.txt" << std::endl;
        std::exit(1);
    }
    std::cout << "Client CC deserialized";

    KeyPair<DCRTPoly> clientKP;  // We do NOT have a secret key. The client
    // should not have access to this
    PublicKey<DCRTPoly> clientPublicKey;
    if (!Serial::DeserializeFromFile(DATAFOLDER + pubKeyLocation, clientPublicKey, SerType::BINARY)) {
        std::cerr << "I cannot read serialized data from: " << DATAFOLDER + pubKeyLocation << std::endl;
        std::exit(1);
    }
    std::cout << "Client KP deserialized" << '\n' << std::endl;

    std::ifstream multKeyIStream(DATAFOLDER + multKeyLocation, std::ios::in | std::ios::binary);
    if (!multKeyIStream.is_open()) {
        std::cerr << "Cannot read serialization from " << DATAFOLDER + multKeyLocation << std::endl;
        std::exit(1);
    }
    if (!clientCC->DeserializeEvalMultKey(multKeyIStream, SerType::BINARY)) {
        std::cerr << "Could not deserialize eval mult key file" << std::endl;
        std::exit(1);
    }

    std::cout << "Deserialized eval mult keys" << '\n' << std::endl;
    std::ifstream rotKeyIStream(DATAFOLDER + rotKeyLocation, std::ios::in | std::ios::binary);
    if (!rotKeyIStream.is_open()) {
        std::cerr << "Cannot read serialization from " << DATAFOLDER + rotKeyLocation << std::endl;
        std::exit(1);
    }
    if (!clientCC->DeserializeEvalAutomorphismKey(rotKeyIStream, SerType::BINARY)) {
        std::cerr << "Could not deserialize eval rot key file" << std::endl;
        std::exit(1);
    }

    Ciphertext<DCRTPoly> clientC1;
    Ciphertext<DCRTPoly> clientC2;
    if (!Serial::DeserializeFromFile(DATAFOLDER + cipherOneLocation, clientC1, SerType::BINARY)) {
        std::cerr << "Cannot read serialization from " << DATAFOLDER + cipherOneLocation << std::endl;
        std::exit(1);
    }
    std::cout << "Deserialized ciphertext1" << '\n' << std::endl;

    if (!Serial::DeserializeFromFile(DATAFOLDER + cipherTwoLocation, clientC2, SerType::BINARY)) {
        std::cerr << "Cannot read serialization from " << DATAFOLDER + cipherTwoLocation << std::endl;
        std::exit(1);
    }

    std::cout << "Deserialized ciphertext2" << '\n' << std::endl;

    clientC1 = clientCC->EvalBootstrap(clientC1);

    auto clientCiphertextMult   = clientCC->EvalMult(clientC1, clientC2);
    auto clientCiphertextAdd    = clientCC->EvalAdd(clientC1, clientC2);
    auto clientCiphertextRot    = clientCC->EvalRotate(clientC1, 1);
    auto clientCiphertextRotNeg = clientCC->EvalRotate(clientC1, -1);

    // Now, we want to simulate a client who is encrypting data for the server to
    // decrypt. E.g weights of a machine learning algorithm
    demarcate("Part 3.5: Client Serialization of data that has been operated on");

    std::vector<std::complex<double>> clientVector1 = {1.0, 2.0, 3.0, 4.0};
    auto clientPlaintext1                           = clientCC->MakeCKKSPackedPlaintext(clientVector1);
    auto clientInitiatedEncryption                  = clientCC->Encrypt(clientPublicKey, clientPlaintext1);
    Serial::SerializeToFile(DATAFOLDER + cipherMultLocation, clientCiphertextMult, SerType::BINARY);
    Serial::SerializeToFile(DATAFOLDER + cipherAddLocation, clientCiphertextAdd, SerType::BINARY);
    Serial::SerializeToFile(DATAFOLDER + cipherRotLocation, clientCiphertextRot, SerType::BINARY);
    Serial::SerializeToFile(DATAFOLDER + cipherRotNegLocation, clientCiphertextRotNeg, SerType::BINARY);
    Serial::SerializeToFile(DATAFOLDER + clientVectorLocation, clientInitiatedEncryption, SerType::BINARY);

    std::cout << "Serialized all ciphertexts from client" << '\n' << std::endl;
}

/**
 * serverVerification
 *  - deserialize data from the client.
 *  - Verify that the results are as we expect
 * @param cc cryptocontext that was previously generated
 * @param kp keypair that was previously generated
 * @param vectorSize vector size of the vectors supplied
 * @return
 *  5-tuple of the plaintexts of various operations
 */

std::tuple<Plaintext, Plaintext, Plaintext, Plaintext, Plaintext> serverVerification(CryptoContext<DCRTPoly>& cc,
                                                                                     KeyPair<DCRTPoly>& kp,
                                                                                     int vectorSize) {
    Ciphertext<DCRTPoly> serverCiphertextFromClient_Mult;
    Ciphertext<DCRTPoly> serverCiphertextFromClient_Add;
    Ciphertext<DCRTPoly> serverCiphertextFromClient_Rot;
    Ciphertext<DCRTPoly> serverCiphertextFromClient_RogNeg;
    Ciphertext<DCRTPoly> serverCiphertextFromClient_Vec;

    Serial::DeserializeFromFile(DATAFOLDER + cipherMultLocation, serverCiphertextFromClient_Mult, SerType::BINARY);
    Serial::DeserializeFromFile(DATAFOLDER + cipherAddLocation, serverCiphertextFromClient_Add, SerType::BINARY);
    Serial::DeserializeFromFile(DATAFOLDER + cipherRotLocation, serverCiphertextFromClient_Rot, SerType::BINARY);
    Serial::DeserializeFromFile(DATAFOLDER + cipherRotNegLocation, serverCiphertextFromClient_RogNeg, SerType::BINARY);
    Serial::DeserializeFromFile(DATAFOLDER + clientVectorLocation, serverCiphertextFromClient_Vec, SerType::BINARY);
    std::cout << "Deserialized all data from client on server" << '\n' << std::endl;

    demarcate("Part 5: Correctness verification");

    Plaintext serverPlaintextFromClient_Mult;
    Plaintext serverPlaintextFromClient_Add;
    Plaintext serverPlaintextFromClient_Rot;
    Plaintext serverPlaintextFromClient_RotNeg;
    Plaintext serverPlaintextFromClient_Vec;

    cc->Decrypt(kp.secretKey, serverCiphertextFromClient_Mult, &serverPlaintextFromClient_Mult);
    cc->Decrypt(kp.secretKey, serverCiphertextFromClient_Add, &serverPlaintextFromClient_Add);
    cc->Decrypt(kp.secretKey, serverCiphertextFromClient_Rot, &serverPlaintextFromClient_Rot);
    cc->Decrypt(kp.secretKey, serverCiphertextFromClient_RogNeg, &serverPlaintextFromClient_RotNeg);
    cc->Decrypt(kp.secretKey, serverCiphertextFromClient_Vec, &serverPlaintextFromClient_Vec);

    serverPlaintextFromClient_Mult->SetLength(vectorSize);
    serverPlaintextFromClient_Add->SetLength(vectorSize);
    serverPlaintextFromClient_Vec->SetLength(vectorSize);
    serverPlaintextFromClient_Rot->SetLength(vectorSize + 1);
    serverPlaintextFromClient_RotNeg->SetLength(vectorSize + 1);

    return std::make_tuple(serverPlaintextFromClient_Mult, serverPlaintextFromClient_Add, serverPlaintextFromClient_Vec,
                           serverPlaintextFromClient_Rot, serverPlaintextFromClient_RotNeg);
}
int main() {
    std::cout << "This program requires the subdirectory `" << DATAFOLDER << "' to exist, otherwise you will get "
              << "an error writing serializations." << std::endl;

    // Set main params
    const int multDepth    = 5;
    const int scaleModSize = 40;
    const usint batchSize  = 32;

    const int cryptoContextIdx = 0;
    const int keyPairIdx       = 1;
    const int vectorSizeIdx    = 2;

    const int cipherMultResIdx   = 0;
    const int cipherAddResIdx    = 1;
    const int cipherVecResIdx    = 2;
    const int cipherRotResIdx    = 3;
    const int cipherRotNegResIdx = 4;

    demarcate(
        "Part 1: Cryptocontext generation, key generation, data encryption "
        "(server)");

    auto tupleCryptoContext_KeyPair = serverSetupAndWrite(multDepth, scaleModSize, batchSize);
    auto cc                         = std::get<cryptoContextIdx>(tupleCryptoContext_KeyPair);
    auto kp                         = std::get<keyPairIdx>(tupleCryptoContext_KeyPair);
    int vectorSize                  = std::get<vectorSizeIdx>(tupleCryptoContext_KeyPair);

    demarcate("Part 3: Client deserialize all data");
    clientProcess();

    demarcate("Part 4: Server deserialization of data from client. ");

    auto tupleRes  = serverVerification(cc, kp, vectorSize);
    auto multRes   = std::get<cipherMultResIdx>(tupleRes);
    auto addRes    = std::get<cipherAddResIdx>(tupleRes);
    auto vecRes    = std::get<cipherVecResIdx>(tupleRes);
    auto rotRes    = std::get<cipherRotResIdx>(tupleRes);
    auto rotNegRes = std::get<cipherRotNegResIdx>(tupleRes);

    // vec1: {1,2,3,4}
    // vec2: {12.5, 13.5, 14.5, 15.5}

    std::cout << multRes << std::endl;  // EXPECT: 12.5, 27.0, 43.5, 62
    std::cout << addRes << std::endl;   // EXPECT: 13.5, 15.5, 17.5, 19.5
    std::cout << vecRes << std::endl;   // EXPECT:  {1,2,3,4}

    std::cout << "Displaying 5 elements of a 4-element vector to illustrate rotation" << '\n';
    std::cout << rotRes << std::endl;     // EXPECT: {2, 3, 4, noise, noise}
    std::cout << rotNegRes << std::endl;  // EXPECT: {noise, 1, 2, 3, 4}
}

Hello, you need to run EvalBootstrapSetup() on the deserialized context before calling EvalBootstrap(), so replacing

clientC1 = clientCC->EvalBootstrap(clientC1);

with

std::vector<uint32_t> levelBudget = {4, 4};
clientCC->EvalBootstrapSetup(levelBudget);
clientC1 = clientCC->EvalBootstrap(clientC1);

will solve your problem.
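
To illustrate why the setup call has to be repeated on the client, here is a toy model in plain C++. It is not OpenFHE code: ToyContext, bootstrapSetup, bootstrap, serialize, deserialize and bootstrapAfterLoad are hypothetical names made up for this sketch. It assumes that the data produced by the setup step is not carried along with the serialized context, which is what the fix above suggests but which I have not verified against the library source.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Toy stand-in for a crypto context. NOT an OpenFHE type.
struct ToyContext {
    uint32_t ringDim = 0;                   // parameter that gets serialized
    std::vector<uint32_t> bootstrapTables;  // derived data, assumed NOT serialized

    // Analogue of EvalBootstrapSetup: derives tables from the level budget.
    void bootstrapSetup(const std::vector<uint32_t>& levelBudget) {
        assert(ringDim > 0 && "parameters must be loaded before setup");
        bootstrapTables.clear();
        for (uint32_t b : levelBudget) {
            bootstrapTables.push_back(b * ringDim);
        }
    }

    // Analogue of EvalBootstrap: refuses to run without the derived tables.
    int bootstrap(int ciphertext) const {
        if (bootstrapTables.empty()) {
            throw std::runtime_error("bootstrap precomputations missing");
        }
        return ciphertext;  // placeholder, a real scheme would refresh the ciphertext
    }
};

// Serialization writes only the parameters, not the derived tables.
std::string serialize(const ToyContext& cc) {
    return std::to_string(cc.ringDim);
}

// Deserialization restores the parameters and leaves the tables empty.
ToyContext deserialize(const std::string& data) {
    ToyContext cc;
    cc.ringDim = static_cast<uint32_t>(std::stoul(data));
    return cc;
}

// Returns true if bootstrap succeeds on a context restored from `data`.
bool bootstrapAfterLoad(const std::string& data, bool rerunSetup) {
    ToyContext client = deserialize(data);
    if (rerunSetup) {
        client.bootstrapSetup({4, 4});  // same level budget as on the server
    }
    try {
        client.bootstrap(42);
        return true;
    } catch (const std::runtime_error&) {
        return false;
    }
}
```

In this model, bootstrapAfterLoad(blob, false) fails in the same way as the second EvalBootstrap call in the question, while bootstrapAfterLoad(blob, true) succeeds because the client rebuilds the derived data first. As in the fix above, the client passes the same level budget that the server used.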