Hello, I am working on an application in a client-server setup where the two parties need to exchange a large number of Ciphertexts. Before sending them over the network the client needs to perform these operations for each data item: Pack, Encrypt, Serialize (to bytes), Archive. When the server receives the archive with the Ciphertexts over the network, it needs to perform the following operations for each Ciphertext: Extract, Deserialize, Decrypt and Unpack.
After timing each of these operations, I found that in the first part the most time-consuming operation is Encrypt, but in the second part it is Deserialize.
Do you have any idea where this cost comes from in the Deserialize operation? Is it inherent to the Serial class (from which the Ciphertext object probably inherits its serialization methods), or is it specific to the OpenFHE team's implementation? Do you see any way of speeding up this operation? It becomes the performance bottleneck when handling a large number of Ciphertexts.
Below you can find a Minimal Reproducible Example (MRE) using openfhe-python v1.3.0 that performs these operations together with the execution times for 10000 Ciphertexts:
Output for 10000 Ciphertexts:
Setting up HE...
Plaintext Modulus: 2148794369 (32 bits)
Ciphertext modulus bitsize (log2 q) = 240
Ring Dimension (m) = 16384
Packing, Encrypting, Serializing and Archiving 10000 Ciphertexts...
ciphertext = 9900/10000 (99.00 %)
Packing took 0.00 seconds
Encrypting took 63.58 seconds
Serializing took 10.39 seconds
Archiving took 9.24 seconds
Total time took 89.30 seconds
Extracting, Deserializing, Decrypting and Unpacking 10000 Ciphertexts...
ciphertext = 9900/10000 (99.00 %)
Extracting took 3.64 seconds
Deserializing took 70.61 seconds
Decrypting took 14.00 seconds
Unpacking took 5.26 seconds
Total time took 93.57 seconds
Code in openfhe-python (note: `total_ciphs` is set to 10 below; set it to 10000 to reproduce the timings above):
import timeit
import zipfile
import openfhe as fhe
import math
# Serialization format shared by both phases (fhe.Serialize on the client,
# fhe.DeserializeCiphertextString on the server) and the zip storage mode
# for the archive (ZIP_STORED: entries are stored without compression).
FILE_TYPE = fhe.BINARY
COMPRESSION_METHOD = zipfile.ZIP_STORED

print(f"Setting up HE...")

# --- BFV-RNS parameters -----------------------------------------------------
# Plaintext modulus 2148794369 (32 bits), multiplicative depth 3,
# HEStd_128_classic security level.
parameters = fhe.CCParamsBFVRNS()
plaintext_modulus = 2148794369
print(f"\tPlaintext Modulus: {plaintext_modulus} ({math.ceil(math.log2(plaintext_modulus))} bits)")
parameters.SetPlaintextModulus(plaintext_modulus)
parameters.SetMultiplicativeDepth(3)
parameters.SetSecurityLevel(fhe.HEStd_128_classic)

# --- Crypto context and enabled feature set ---------------------------------
crypto_context = fhe.GenCryptoContext(parameters)
for feature in (
    fhe.PKESchemeFeature.PKE,
    fhe.PKESchemeFeature.KEYSWITCH,
    fhe.PKESchemeFeature.LEVELEDSHE,
    fhe.ADVANCEDSHE,
):
    crypto_context.Enable(feature)

# --- Report the parameters chosen by the context ----------------------------
# ring_dim is reused by the server phase to size the unpacked result.
ring_dim = crypto_context.GetRingDimension()
q = crypto_context.GetModulus()
q_bitlength = int(q).bit_length()
print(f"\tCiphertext modulus bitsize (log2 q) = {q_bitlength}")
print(f"\tRing Dimension (m) = {crypto_context.GetRingDimension()}")
print()

# --- Key generation (timed as one unit) -------------------------------------
# Key pair, EvalMult key, and rotation keys for indices 1, 2 and 4.
# NOTE(review): keygen_time is computed but never printed in this script.
start_time = timeit.default_timer()
key_pair = crypto_context.KeyGen()
crypto_context.EvalMultKeyGen(key_pair.secretKey)
crypto_context.EvalRotateKeyGen(key_pair.secretKey, [1, 2, 4])
keygen_time = timeit.default_timer() - start_time
# --- Client side: pack, encrypt, serialize and archive ----------------------
# Every ciphertext encrypts the same packed vector `data` and is written as
# its own uncompressed entry `ciphertext_<i>.ciph` in `archive_name`.
# NOTE(review): the timings quoted with this script were for 10000
# ciphertexts; total_ciphs is set much lower here.
data = [1, 2, 3, 4, 5]
archive_name = "archive.zip"
total_ciphs = 10

# Per-phase wall-clock totals in seconds, summed over all ciphertexts.
pack_time = 0
encrypt_time = 0
serialize_time = 0
archive_time = 0

total_pack_tic = timeit.default_timer()
# Using zip (without compression)
print(f"Packing, Encrypting, Serializing and Archiving {total_ciphs} Ciphertexts...")
with zipfile.ZipFile(archive_name, 'w', COMPRESSION_METHOD) as zipf:
    for i_ciph in range(total_ciphs):
        # Display progress every 100 ciphertexts, overwriting the same line.
        if i_ciph % 100 == 0:
            print(f"ciphertext = {i_ciph}/{total_ciphs} ({i_ciph / total_ciphs * 100:.2f} %)", end='\r', flush=True)

        # Pack
        start_time = timeit.default_timer()
        data_packed = crypto_context.MakePackedPlaintext(data)
        # Accumulate like every other phase. Plain assignment here reported
        # only the last iteration's packing time.
        pack_time += timeit.default_timer() - start_time

        # Encrypt
        start_time = timeit.default_timer()
        ciph = crypto_context.Encrypt(key_pair.publicKey, data_packed)
        encrypt_time += timeit.default_timer() - start_time

        # Serialize
        start_time = timeit.default_timer()
        serial_ciph = fhe.Serialize(ciph, FILE_TYPE)
        serialize_time += timeit.default_timer() - start_time

        # Archive
        start_time = timeit.default_timer()
        ciph_filename = f'ciphertext_{i_ciph}.ciph'
        zipf.writestr(ciph_filename, serial_ciph)
        archive_time += timeit.default_timer() - start_time
total_pack_time = timeit.default_timer() - total_pack_tic

print()
print(f"Packing took {pack_time:.2f} seconds")
print(f"Encrypting took {encrypt_time:.2f} seconds")
print(f"Serializing took {serialize_time:.2f} seconds")
print(f"Archiving took {archive_time:.2f} seconds")
print(f"Total time took {total_pack_time:.2f} seconds")
print()
#####################################################################################
# --- Server side: extract, deserialize, decrypt and unpack ------------------
# Reads back every entry of `archive_name` (in archive order) and times each
# phase separately; the totals are printed at the end.
# NOTE(review): a serialized ciphertext presumably also carries its crypto
# context / parameters, which DeserializeCiphertextString would rebuild and
# match against the existing context on every call. That is a likely source
# of the per-ciphertext deserialization cost; confirm with a profiler.
extract_time = deserialize_time = decrypt_time = unpack_time = 0
clock = timeit.default_timer

total_start_time = clock()
print(f"Extracting, Deserializing, Decrypting and Unpacking {total_ciphs} Ciphertexts...")
with zipfile.ZipFile(archive_name, 'r') as reader:
    for idx, entry_name in enumerate(reader.namelist()):
        # Progress line every 100 entries (denominator is the client's count).
        if idx % 100 == 0:
            print(f"ciphertext = {idx}/{total_ciphs} ({idx / total_ciphs * 100:.2f} %)", end='\r', flush=True)

        t0 = clock()
        raw_bytes = reader.read(entry_name)
        extract_time += clock() - t0

        t0 = clock()
        restored = fhe.DeserializeCiphertextString(raw_bytes, FILE_TYPE)
        deserialize_time += clock() - t0

        t0 = clock()
        decrypted = crypto_context.Decrypt(restored, key_pair.secretKey)
        decrypt_time += clock() - t0

        # Unpacking reads all ring_dim slots, not just len(data) of them.
        t0 = clock()
        decrypted.SetLength(ring_dim)
        unpacked_result = decrypted.GetPackedValue()
        unpack_time += clock() - t0
total_unp_time = clock() - total_start_time

print()
print(f"Extracting took {extract_time:.2f} seconds")
print(f"Deserializing took {deserialize_time:.2f} seconds")
print(f"Decrypting took {decrypt_time:.2f} seconds")
print(f"Unpacking took {unpack_time:.2f} seconds")
print(f"Total time took {total_unp_time:.2f} seconds")
print()