diff --git a/bin/ntt.o b/bin/ntt.o new file mode 100644 index 0000000000000000000000000000000000000000..40679a382429fbc1a6bdbe246ab1e757cb8db3aa Binary files /dev/null and b/bin/ntt.o differ diff --git a/bin/seckey.o b/bin/seckey.o index a4cfdc38f9f44fbab072cae7adc4365fde3b5734..257a327b70529a10066ffeb0315e9c7bb497942d 100644 Binary files a/bin/seckey.o and b/bin/seckey.o differ diff --git a/comptest.sh b/comptest.sh index a5908bd2734b83a8a9bc22118474b58a7b90f29a..c72f03683d76e76102f3ec3f05693a8a4e53d1cf 100755 --- a/comptest.sh +++ b/comptest.sh @@ -1,4 +1,4 @@ #!/bin/bash g++ -c -o test.o test.cpp -g++ -o run test.o ./bin/ciphertext.o ./bin/ckks.o ./bin/encoder.o ./bin/polynomial.o ./bin/pubkey.o ./bin/seckey.o ./bin/evalkey.o \ No newline at end of file +g++ -o run test.o ./bin/ciphertext.o ./bin/ckks.o ./bin/encoder.o ./bin/polynomial.o ./bin/pubkey.o ./bin/seckey.o ./bin/evalkey.o ./bin/ntt.o \ No newline at end of file diff --git a/run b/run index 5d77c5dfd8013fe1aa2b97ea64d0029271bd1f2b..f8a7d286a7d0f4d0e133f53be7216f3c2821f13e 100755 Binary files a/run and b/run differ diff --git a/src/ckks.cpp b/src/ckks.cpp index 30e7d9d826eae96b9eb38a34047ae6e434d2e191..ec67b53929e96b45d204241282545ff5a06028c1 100644 --- a/src/ckks.cpp +++ b/src/ckks.cpp @@ -73,7 +73,7 @@ Ciphertext CKKS::mult(Ciphertext ct1, Ciphertext ct2){ Polynomial d2 = ((ct1.c0 * ct2.c1) + (ct2.c0 * ct1.c1)).modCoeff(ql); Polynomial d3 = (ct1.c1 * ct2.c1).modCoeff(ql); - // Relin --> Still weird + // Relin Polynomial d3_0 = (d3 * evk.b).scaleRoundCoeff(1.0/1000.0); Polynomial d3_1 = (d3 * evk.a).scaleRoundCoeff(1.0/1000.0); diff --git a/src/compile.sh b/src/compile.sh index 0bb493acf053a07abd5bd82ac9246398f42d056b..0c7adde5b9bf274114e1632fffa6412ab2463d1a 100755 --- a/src/compile.sh +++ b/src/compile.sh @@ -7,3 +7,4 @@ g++ -c -o ../bin/polynomial.o polynomial.cpp g++ -c -o ../bin/pubkey.o pubkey.cpp g++ -c -o ../bin/seckey.o seckey.cpp g++ -c -o ../bin/evalkey.o evalkey.cpp +g++ -c -o ../bin/ntt.o ntt.cpp diff --git a/src/ntt.cpp b/src/ntt.cpp index ea46e39f090b8533c454fb30006697c469a0c780..40be4f1d7c70eeed5024da63d00be543bf5578ce 100644 --- a/src/ntt.cpp +++ b/src/ntt.cpp @@ -18,41 +18,138 @@ int NTT::modExp(int base, int power, int mod){ int NTT::modInv(int x, int mod){ int t = 0; int t1 = 1; - int r, r1 = mod, x; + int r = mod; + int r1 = x; while (r1 != 0){ int quot = (int) (r/r1); - int temp_t1 = t1; int temp_t = t; - int temp_r1 = r1; int temp_r = r; - t, t1 = t1, (t - quot * t1); - r, r1 = r1, (r % r1); + t = t1; + t1 = (temp_t - quot * t1); + r = r1; + r1 = (temp_r % r1); } - - if (t<0){ + + if (t < 0){ t = t + mod; } return t; } -void NTT::_ntt(vector<int64_t> &in, bool inverse){ +int NTT::bitLength(int x){ + int a = x; + int len = 0; + while (a!=0){ + a >>= 1; + len += 1; + } + return len; +} + +bool NTT::existSmallerN(int r, int mod, int n){ + for(int k=2; k<n; k++){ + if(modExp(r, k, mod) == 1){ + return true; + } + } + return false; +} + +int64_t NTT::genNthRoot(int mod, int n){ + int p = mod - 1; + int range = mod - 1 + 1; + while (true){ + int64_t a = rand() % range + 1; + int64_t b = modExp(a, p/n, mod); + if (!existSmallerN(b, mod, n)){ + return b; + } + } +} + +void NTT::reverse(vector<int64_t> &in, int bitLen){ + for (int i=0; i<size(in); i++){ + int revN = 0; + for(int j=0; j<bitLen ; j++){ + if ((i >> j) & 1){ + revN |= 1 << (bitLen-1-j); + } + } + int coeff = in[i]; + + if (revN > i){ + coeff ^= in[revN]; + in[revN] ^= coeff; + coeff ^= in[revN]; + in[i] = coeff; + } + } +} + +void NTT::_ntt(vector<int64_t> &in, int64_t w){ + int N = size(in); + int nBit = bitLength(N) - 1; + reverse(in, nBit); + + vector<int> points(N, 0); + for(int i=0; i<nBit; i++){ + vector<int64_t> p1; + vector<int64_t> p2; + for(int j=0; j<N/2; j++){ + int shift = nBit - i - 1; + int P = (j >> shift) << shift; + int wP = modExp(w, P, M); + int64_t odd = in[2*j+1] * wP; + int64_t even = in[2*j]; + p1.push_back((even + odd) % M); + p2.push_back((even - odd) % M); + } + + for(int k=0; k<N/2; k++){ + points[k] = p1[k]; + points[k+N/2] = p2[k]; + } + + if(1!=nBit){ + for(int k=0; k<N; k++){ + in[k] = points[k]; + } + } + } + for(int k=0; k<N; k++){ + in[k] = points[k]; + if(in[k] < 0){ + in[k] += M; + } + } } -vector<int64_t> NTT::ntt(Polynomial in){ - vector<int64_t> out(in.degree,0); - _ntt(out, true); +vector<int64_t> NTT::ntt(Polynomial in, int degree, int64_t w){ + vector<int64_t> out(degree,0); + for(int i=0; i<degree; i++){ + out[i] = (int64_t) in.coeffs[i]; + } + _ntt(out, w); return out; } -Polynomial NTT::intt(vector<int64_t> in){ - vector<int64_t> out(size(in),0); - Polynomial pOut(size(out)); - int n_inv = modInv(size(out), M); +Polynomial NTT::intt(vector<int64_t> &in, int w){ + int N = size(in); + Polynomial pOut(N); + double coeff[N]; + + int wInv = modInv(w, M); + int nInv = modInv(N, M); - _ntt(out,false); + _ntt(in, wInv); + + for(int i=0; i<size(in); i++){ + coeff[i] = (in[i] * nInv) % M; + } + pOut.setCoeffs(coeff); return pOut; } \ No newline at end of file diff --git a/src/ntt.h b/src/ntt.h index 5f3646928ed19ba18c7fd9b0a0e3fb3805d3443c..74a6110ac3c456c214e46a32ceb1afea44bd1751 100644 --- a/src/ntt.h +++ b/src/ntt.h @@ -12,12 +12,16 @@ using namespace std; class NTT{ public: int g = 3; - int M = 998244353; + int M = 2013265921; int modExp(int base, int power, int mod); int modInv(int x, int mod); - void _ntt(vector<int64_t> &in, bool inverse); - vector<int64_t> ntt(Polynomial in); - Polynomial intt(vector<int64_t> in); + int bitLength(int x); + void reverse(vector<int64_t> &in, int bitLen); + bool existSmallerN(int r, int mod, int n); + int64_t genNthRoot(int mod, int n); + void _ntt(vector<int64_t> &in, int64_t w); + vector<int64_t> ntt(Polynomial in, int degree, int64_t w); + Polynomial intt(vector<int64_t> &in, int w); }; diff --git a/src/polynomial.h b/src/polynomial.h index 04e8de8d713e73e38ee069131d91cbcf4cdfd1ca..e1322333c7fc0c48318b7f977c04217ac7fb36c6 100644 --- a/src/polynomial.h +++ b/src/polynomial.h @@ -1,3 +1,4 @@ +#include<iostream> #ifndef POLYNOMIAL_H #define POLYNOMIAL_H diff --git a/test.cpp b/test.cpp index bab3c670b54c25ccdf1a266938e22f2bf6f41ad3..ce8fe33fb82f8ad660557f307523923d57e547e6 100644 --- a/test.cpp +++ b/test.cpp @@ -2,6 +2,8 @@ #include "src/encoder.h" #include "src/polynomial.h" #include "src/ckks.h" +#include "src/ntt.h" +#include <vector> #include <string> @@ -57,5 +59,22 @@ int main(){ } cout << endl; + cout << "NTT test" << endl; + NTT ntt; + int64_t w = 1728404513; + double coeff[] = {2, 1, 0, 0}; + Polynomial n(4, coeff); + + vector<int64_t> test = ntt.ntt(n, 4, w); + cout <<"test"<<endl; + for(int i=0;i<4;i++){ + cout << test[i] << " "; + } + cout << endl; + + cout <<"test"<<endl; + Polynomial testOut = ntt.intt(test, w); + testOut.printPol(); + return 0; } \ No newline at end of file diff --git a/test.o b/test.o index 1fd77b778c7af00a5f6a8ead0a26b90c439a17f1..4a9916dc459ad9e40041cc478df7c8e32c0585b8 100644 Binary files a/test.o and b/test.o differ