diff --git a/cuda/ckks.cu b/cuda/ckks.cu index ff47c80a5c9bdf28f501b9c586bb1bd9181e33bf..8181a29b2b4c68cb611702959fe394d29069686b 100644 --- a/cuda/ckks.cu +++ b/cuda/ckks.cu @@ -78,7 +78,28 @@ int getNextPow(int deg){ } Ciphertext CKKS::multNTT(Ciphertext ct1, Ciphertext ct2){ - level += 1; - Ciphertext out = mult(ct1, ct2); + int deg = getNextPow(ct1.c0.degree+ct2.c0.degree-1); + int64_t w = ntt.genNthRoot(ntt.M, deg); + + Polynomial d1 = fastMult(ct1.c0, ct2.c0, deg, w).modCoeff(ql); + Polynomial d2 = (fastMult(ct1.c0, ct2.c1,deg, w) + fastMult(ct2.c0, ct1.c1,deg, w)).modCoeff(ql); + Polynomial d3 = fastMult(ct1.c1, ct2.c1,deg, w).modCoeff(ql); + + // Relin + Polynomial d3_0 = (d3 * evk.b).scaleRoundCoeff(1.0/1000.0); + Polynomial d3_1 = (d3 * evk.a).scaleRoundCoeff(1.0/1000.0); + + Polynomial outC0 = d1.modCoeff(ql) + d3_0.modCoeff(ql); + Polynomial outC1 = d2.modCoeff(ql) + d3_1.modCoeff(ql); + + // Rescale + ql = (double)ql / (double)pl[level-1]; + Polynomial c0 = outC0.scaleRoundCoeff(1.0/(double)pl[level-1]); + Polynomial c1 = outC1.scaleRoundCoeff(1.0/(double)pl[level-1]); + + level -= 1; + + Ciphertext out(c0.modCoeff(ql), c1.modCoeff(ql)); + return out; } \ No newline at end of file diff --git a/src/ckks.cpp b/src/ckks.cpp index fb2b43f8df1275fae460c1deccb8ce7d1985ab62..a7b04d4967999237636e723a977206780512c706 100644 --- a/src/ckks.cpp +++ b/src/ckks.cpp @@ -145,7 +145,7 @@ int getNextPow(int deg){ Ciphertext CKKS::multNTT(Ciphertext ct1, Ciphertext ct2){ int deg = getNextPow(ct1.c0.degree+ct2.c0.degree-1); - int64_t w = 1592366214; //ntt.genNthRoot(ntt.M, deg); + int64_t w = ntt.genNthRoot(ntt.M, deg); Polynomial d1 = fastMult(ct1.c0, ct2.c0, deg, w).modCoeff(ql); Polynomial d2 = (fastMult(ct1.c0, ct2.c1,deg, w) + fastMult(ct2.c0, ct1.c1,deg, w)).modCoeff(ql);