diff --git a/djikstra_cuda.cu b/djikstra_cuda.cu
index e4b0a24c6440b8c1a4f81166ed5630f36570b854..6380533ccebe110e69f4a67ab4c8d80bca017003 100644
--- a/djikstra_cuda.cu
+++ b/djikstra_cuda.cu
@@ -34,10 +34,13 @@ void printGraph(int n, int graph[n]){
 }
 
 __device__ 
-void djikstra(int graph[],int start,int distance[],int visited[],int n,int answer[]){
+void djikstra(int graph[],int start,int n,int answer[]){
 
     int node;
     int temp_dist;
+    int *distance,*visited;
+    distance = malloc(n*sizeof(int));
+    visited = malloc(n*sizeof(int));
     int min_dist[2]; // node and its distance
     min_dist[0] = -1;
     min_dist[1] = INFINITY;
@@ -78,18 +81,19 @@ void djikstra(int graph[],int start,int distance[],int visited[],int n,int answe
     }
     for(int i=0;i<n;i++){
         answer[start*n+i] = distance[i];
-
     }
+    free(distance);
+    free(visited);
 }
 
 __global__
-void djikstra_cuda(int graph[]int distance[],int visited[],int n,int answer[])
+void djikstra_cuda(int graph[],int n,int answer[])
 {
     int index = blockIdx.x * blockDim.x + threadIdx.x;
     int stride = blockDim.x * gridDim.x;
     for (int i = index; i < n; i += stride){
-        djikstra(graph,i,distance,visited[],n,answer);
-  }
+        djikstra(graph,i,n,answer);
+    }
 }
 
 int main(int argc, char** argv){
@@ -111,8 +115,6 @@ int main(int argc, char** argv){
 
     int *graph,*distance,*visited, *answer;
     graph = malloc(n*n*sizeof(int));
-    distance = malloc(n*sizeof(int));
-    visited = malloc(n*sizeof(int));
     answer = malloc(n*n*sizeof(int));
     createGraph(n,graph);
     fprintf(output_file, "GRAPH:\n");
@@ -128,7 +130,7 @@ int main(int argc, char** argv){
     int numBlocks = (n + blockSize - 1) / blockSize;
 
     gettimeofday(&st,NULL);
-    djikstra_cuda<<<numBlocks, blockSize>>>>(graph,distance,visited,n,answer);
+    djikstra_cuda<<<numBlocks, blockSize>>>>(graph,n,answer);
     gettimeofday(&et,NULL);
     unsigned long long int elapsed = ((et.tv_sec - st.tv_sec) * 1000000) + (et.tv_usec - st.tv_usec);
     printf("Djikstra total time: %llu micro seconds\n",elapsed);