diff --git a/dijkstra_cuda.cu b/dijkstra_cuda.cu
index e94fc08ba21f5c45e89a33a07437b32e10175e04..210cb181a8a50b94fa7984f584b4f8dd05db0145 100644
--- a/dijkstra_cuda.cu
+++ b/dijkstra_cuda.cu
@@ -75,15 +75,14 @@ __device__ int minDistance(int *dist, bool sptSet[], int n)
     return min_index; 
 } 
 
-__device__ void dijkstra(int **graph, int **result, int src, int n) 
+__device__ void dijkstra(int **graph, int **result, bool *sptSet, int src, int n) 
 { 
     int *dist; // The output array.  dist[i] will hold the shortest 
     // distance from src to i
     dist = result[src];
 
-    bool *sptSet; // sptSet[i] will be true if vertex i is included in shortest 
+    // sptSet[i] will be true if vertex i is included in shortest 
     // path tree or shortest distance from src to i is finalized
-    sptSet = (bool *)malloc(sizeof(bool) * n);
 
     // Initialize all distances as INFINITE and stpSet[] as false 
     for (int i = 0; i < n; i++) 
@@ -116,13 +115,13 @@ __device__ void dijkstra(int **graph, int **result, int src, int n)
 /*
  * All Pairs Shortest Path with dijkstra algorithm
  */
-__global__ void dijkstra_APSP(int **matrix, int **result, int n) {
+__global__ void dijkstra_APSP(int **matrix, int **result, bool *sptSet, int n) {
 
     int index = blockIdx.x * blockDim.x + threadIdx.x;
     int stride = blockDim.x * gridDim.x;
 
     for (int i = index; i < n; i += stride)
-        dijkstra(matrix, result, i, n);
+        dijkstra(matrix, result, &sptSet[i*n], i, n);
 }
 
 int main(int argc, char **argv){
@@ -145,12 +144,15 @@ int main(int argc, char **argv){
     int **result;
     malloc_matrix(&result, n);
 
+    bool *sptSet;
+    cudaMallocManaged(&sptSet, sizeof(bool) * n * n);
+
     int blockSize = 256;
     int numBlocks = (n + blockSize - 1) / blockSize;
 
     cudaEventRecord(start);
     // Start dijkstra APSP
-    dijkstra_APSP<<<12, blockSize>>>(matrix, result, n);
+    dijkstra_APSP<<<12, blockSize>>>(matrix, result, sptSet, n);
     cudaEventRecord(stop);
 
     cudaEventSynchronize(stop);
@@ -175,4 +177,4 @@ int main(int argc, char **argv){
         printf("CUDA error: %s\n", cudaGetErrorString(err));
 
     return 0;
-}
\ No newline at end of file
+}