From a775c6d62105953fc37d68c803fb71c25002b48d Mon Sep 17 00:00:00 2001
From: DanielMariooR <13519031@std.stei.itb.ac.id>
Date: Wed, 14 Jun 2023 15:57:39 +0700
Subject: [PATCH] fix mult

---
 cuda/ckks.cu | 25 +++++++++++++++++++++++--
 src/ckks.cpp |  2 +-
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/cuda/ckks.cu b/cuda/ckks.cu
index ff47c80..8181a29 100644
--- a/cuda/ckks.cu
+++ b/cuda/ckks.cu
@@ -78,7 +78,28 @@ int getNextPow(int deg){
 }
 
 Ciphertext CKKS::multNTT(Ciphertext ct1, Ciphertext ct2){
-  level += 1;
-  Ciphertext out = mult(ct1, ct2);
+  int deg = getNextPow(ct1.c0.degree+ct2.c0.degree-1);
+  int64_t w = ntt.genNthRoot(ntt.M, deg);
+
+  Polynomial d1 = fastMult(ct1.c0, ct2.c0, deg, w).modCoeff(ql);
+  Polynomial d2 = (fastMult(ct1.c0, ct2.c1,deg, w) + fastMult(ct2.c0, ct1.c1,deg, w)).modCoeff(ql);  
+  Polynomial d3 = fastMult(ct1.c1, ct2.c1,deg, w).modCoeff(ql);
+
+  // Relin 
+  Polynomial d3_0 = (d3 * evk.b).scaleRoundCoeff(1.0/1000.0);
+  Polynomial d3_1 = (d3 * evk.a).scaleRoundCoeff(1.0/1000.0);
+
+  Polynomial outC0 = d1.modCoeff(ql) + d3_0.modCoeff(ql);
+  Polynomial outC1 = d2.modCoeff(ql) + d3_1.modCoeff(ql);
+
+  // Rescale
+  ql = (double)ql / (double)pl[level-1];
+  Polynomial c0 = outC0.scaleRoundCoeff(1.0/(double)pl[level-1]);
+  Polynomial c1 = outC1.scaleRoundCoeff(1.0/(double)pl[level-1]);
+  
+  level -= 1;
+
+  Ciphertext out(c0.modCoeff(ql), c1.modCoeff(ql));
+
   return out;
 }
\ No newline at end of file
diff --git a/src/ckks.cpp b/src/ckks.cpp
index fb2b43f..a7b04d4 100644
--- a/src/ckks.cpp
+++ b/src/ckks.cpp
@@ -145,7 +145,7 @@ int getNextPow(int deg){
 
 Ciphertext CKKS::multNTT(Ciphertext ct1, Ciphertext ct2){
   int deg = getNextPow(ct1.c0.degree+ct2.c0.degree-1);
-  int64_t w = 1592366214; //ntt.genNthRoot(ntt.M, deg);
+  int64_t w = ntt.genNthRoot(ntt.M, deg);
 
   Polynomial d1 = fastMult(ct1.c0, ct2.c0, deg, w).modCoeff(ql);
   Polynomial d2 = (fastMult(ct1.c0, ct2.c1,deg, w) + fastMult(ct2.c0, ct1.c1,deg, w)).modCoeff(ql);  
-- 
GitLab