Dijkstra’s Algorithm for Adjacency Matrix

Jude Capachietti
6 min read · Dec 18, 2021

Ahh, Dijkstra’s algorithm. It’s a classic, and every time I need to code it up again, I feel like I am starting from square one. This is because many of the resources explaining Dijkstra’s algorithm on the internet are unclear, incomplete, or just plain wrong, or the code is written for a dictionary representation of a graph when I am dealing with an adjacency matrix.

I am endeavoring to fix this problem for both future me and others seeking to understand this algorithm and code it quickly instead of sifting through the interwebs for ages trying to pin down how to approach it.

Note: I am referring to Nodes or Vertices in graph theory as Points in this article because I don’t like the words vertex or node.

First off, here is a YouTube video that I found very helpful, and not too long: https://www.youtube.com/watch?v=pVfj6mxhdMw

So, let’s start from the beginning (and quickly move past it).

Q: What is Dijkstra’s algorithm?

A: An algorithm to find the shortest (cheapest, etc…) path from a start point to an end point on a graph.

Q: What does it do?

A: It finds the shortest path from the start point to every other point on the graph. The end point is one of those points (unless it cannot be reached at all). For each point, it also saves the previous point on the shortest path to it, so the full path to any point can be recovered by following these previous points back until the starting point is reached.

Q: How does it find the shortest path to every other point?

A: By calculating the distance from the starting point to every point. To do this, we traverse the graph by always going to the unvisited point that is currently closest to the starting point. At that point, we add every point that can be reached from it, together with its distance from the starting point along this path, to our priority queue of points to look at next, and then mark the current point as visited. Then we move on to the next closest unvisited point, and so on until the whole graph has been visited.

Q: How do we know that when a point is marked as visited, the shortest path to it has already been found?

A: Because we use a min-heap (a min priority queue) to pick the next point by its distance from the starting point, and because no step has a negative cost, the first time we arrive at a point we have taken the shortest possible path to get there: any other route would have to pass through a point that is at least as far from the start. Other, longer paths to the same point may still be sitting in the min-heap/min priority queue, but they cannot improve on the result, because the point will already be marked as visited by the time they come up.
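To make the "other paths still sitting in the heap" idea concrete, here is a minimal sketch. The heap contents are made up for illustration (they are not taken from the example graph later in this article): the same point is pushed twice with different distances, and the longer entry is simply skipped when it finally comes up.

```python
import heapq

# Hypothetical heap contents: point (1, 1) was reached by two different paths.
min_heap = []
heapq.heappush(min_heap, (12, 1, 1))  # (1, 1) via a longer path
heapq.heappush(min_heap, (4, 1, 1))   # (1, 1) via a shorter path
heapq.heappush(min_heap, (7, 0, 2))   # a different point

visited = set()
finalized = []  # (point, distance) in the order the points are settled

while min_heap:
    dist, row, col = heapq.heappop(min_heap)
    if (row, col) in visited:
        # Stale entry: this point was already settled by a shorter path.
        continue
    visited.add((row, col))
    finalized.append(((row, col), dist))

print(finalized)  # [((1, 1), 4), ((0, 2), 7)]
```

The entry (12, 1, 1) is popped last and ignored, so the distance recorded for (1, 1) stays at 4.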

So enough with the Q&A. Before I dump the code here, I want to outline assumptions that this implementation is making:

  • All points in the adjacency matrix that are next to each other are connected. That is, there are no “blocked” spaces, and every point in the adjacency matrix can be traversed. This is why the stopping criterion for the while loop is that the number of visited points equals the total number of points. If there are blocked spaces in your problem, you will need to count them, subtract them from the number of traversable points, and mark them as already visited (there might be other or better ways to do this).
  • The directions of travel are up, down, left, and right. These are stored as tuples in directions and are added to a point to get the points around it. If diagonals are to be included, add (1,1), (1,-1), (-1,-1), (-1,1) to directions.
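If your problem does break these assumptions, one possible variation is sketched below. It is not the implementation this article builds, and several things in it are my own assumptions: the function name shortest_distances, the allow_diagonals flag, and the convention that a blocked space is stored as None in the graph. Instead of counting visited points, it runs until the heap is empty, so blocked or unreachable points simply keep a distance of infinity.

```python
import heapq

ORTHOGONAL = [(0, 1), (1, 0), (-1, 0), (0, -1)]
DIAGONAL = [(1, 1), (1, -1), (-1, -1), (-1, 1)]


def shortest_distances(graph, start_point, allow_diagonals=False):
    """Dijkstra on a grid; a value of None marks a blocked space (assumption)."""
    n, m = len(graph), len(graph[0])
    directions = ORTHOGONAL + DIAGONAL if allow_diagonals else ORTHOGONAL
    distance = [[float('inf')] * m for _ in range(n)]
    visited = [[False] * m for _ in range(n)]
    distance[start_point[0]][start_point[1]] = 0
    min_heap = [(0, start_point[0], start_point[1])]

    # Loop until the heap is empty rather than counting visited points.
    while min_heap:
        dist, row, col = heapq.heappop(min_heap)
        if visited[row][col]:
            continue  # stale entry for a point that is already settled
        visited[row][col] = True
        for d_row, d_col in directions:
            new_row, new_col = row + d_row, col + d_col
            if not (0 <= new_row < n and 0 <= new_col < m):
                continue  # off the grid
            if graph[new_row][new_col] is None or visited[new_row][new_col]:
                continue  # blocked, or already settled
            new_dist = dist + graph[new_row][new_col]
            if new_dist < distance[new_row][new_col]:
                distance[new_row][new_col] = new_dist
                heapq.heappush(min_heap, (new_dist, new_row, new_col))
    return distance


# A small made-up grid with a wall in the middle column.
grid = [
    [1, None, 1],
    [1, None, 1],
    [1, 1, 1],
]
print(shortest_distances(grid, (0, 0))[0][2])                        # 6
print(shortest_distances(grid, (0, 0), allow_diagonals=True)[0][2])  # 4
```

With only up/down/left/right moves the path has to go all the way around the wall (6 steps of cost 1); with diagonals it can cut the corners at the bottom (4 steps).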

Also, it is worth noting that when a point is added to the min-heap/min priority queue, it is added as a tuple in the form of (distance from starting point to this point on this path, point's row, point's column). There are 3 values all at the same depth (even though it would seem right for the row and column to be in their own tuple), so that when two points have the same distance from the starting point, the min-heap breaks the tie on the row and then the column.
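Here is a small sketch of that tie-breaking, using made-up entries. Python compares tuples element by element from left to right, so heapq orders entries by distance first, then by row, then by column.

```python
import heapq

# Hypothetical entries in the form (distance, row, col); three of them tie on distance.
entries = [(5, 2, 3), (5, 1, 9), (3, 4, 4), (5, 1, 2)]

min_heap = []
for entry in entries:
    heapq.heappush(min_heap, entry)

pop_order = [heapq.heappop(min_heap) for _ in range(len(entries))]
print(pop_order)  # [(3, 4, 4), (5, 1, 2), (5, 1, 9), (5, 2, 3)]
```

For what it's worth, a nested (distance, (row, col)) tuple would compare the same way in Python; the flat form is simply the one used in this article.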

To start, I created a function that takes a graph and the start_point (a tuple) as input and returns 2 graphs of the same shape: one holding the distance of the shortest path from the starting point to every point on the graph, and one holding, for each point, the previous point that was taken on the shortest path to reach it.

Here’s the code:

import heapq


def find_shortest_paths(graph, start_point):
    # initialize graphs to track if a point is visited,
    # current calculated distance from start to point,
    # and previous point taken to get to current point
    visited = [[False for col in row] for row in graph]
    distance = [[float('inf') for col in row] for row in graph]
    distance[start_point[0]][start_point[1]] = 0
    prev_point = [[None for col in row] for row in graph]
    n, m = len(graph), len(graph[0])
    number_of_points, visited_count = n * m, 0
    directions = [(0, 1), (1, 0), (-1, 0), (0, -1)]
    min_heap = []

    # min_heap item format:
    # (pt's dist from start on this path, pt's row, pt's col)
    heapq.heappush(min_heap, (distance[start_point[0]][start_point[1]], start_point[0], start_point[1]))

    while visited_count < number_of_points:
        current_point = heapq.heappop(min_heap)
        distance_from_start, row, col = current_point
        # skip stale heap entries: this point was already visited via a
        # shorter path, and counting it again would end the loop too early
        if visited[row][col]:
            continue
        for direction in directions:
            new_row, new_col = row + direction[0], col + direction[1]
            # only look at neighbors that are on the graph and unvisited
            if -1 < new_row < n and -1 < new_col < m and not visited[new_row][new_col]:
                # the cost of a step is the value of the point being entered
                dist_to_new_point = distance_from_start + graph[new_row][new_col]
                if dist_to_new_point < distance[new_row][new_col]:
                    distance[new_row][new_col] = dist_to_new_point
                    prev_point[new_row][new_col] = (row, col)
                    heapq.heappush(min_heap, (dist_to_new_point, new_row, new_col))
        visited[row][col] = True
        visited_count += 1

    return distance, prev_point

Let’s say you have a graph that looks like this:

graph = [
    [1, 1, 6, 3, 7, 5, 1, 7, 4, 2],
    [1, 3, 8, 1, 3, 7, 3, 6, 7, 2],
    [2, 1, 3, 6, 5, 1, 1, 3, 2, 8],
    [3, 6, 9, 4, 9, 3, 1, 5, 6, 9],
    [7, 4, 6, 3, 4, 1, 7, 1, 1, 1],
    [1, 3, 1, 9, 1, 2, 8, 1, 3, 7],
    [1, 3, 5, 9, 9, 1, 2, 4, 2, 1],
    [3, 1, 2, 5, 4, 2, 1, 6, 3, 9],
    [1, 2, 9, 3, 1, 3, 8, 5, 2, 1],
    [2, 3, 1, 1, 9, 4, 4, 5, 8, 1],
]

To find the shortest distance from (0,0) to all points on the graph you would do:

distance, prev_point = find_shortest_paths(graph, (0, 0))

distance would look like:

[[0, 1, 7, 10, 17, 22, 23, 30, 34, 36],
[1, 4, 12, 11, 14, 21, 23, 29, 32, 34],
[3, 4, 7, 13, 18, 19, 20, 23, 25, 33],
[6, 10, 16, 17, 26, 22, 21, 26, 31, 38],
[13, 14, 20, 20, 24, 23, 28, 27, 28, 29],
[14, 17, 18, 27, 25, 25, 33, 28, 31, 36],
[15, 18, 23, 32, 34, 26, 28, 32, 33, 34],
[18, 19, 21, 26, 30, 28, 29, 35, 36, 43],
[19, 21, 30, 29, 30, 31, 37, 40, 38, 39],
[21, 24, 25, 26, 35, 35, 39, 44, 46, 40]]

And prev_point would look like:

[[ None,(0,0),(0,1),(0,2),(0,3),(0,4),(0,5),(0,6),(0,7),(0,8)],
 [(0,0),(0,1),(1,1),(0,3),(1,3),(1,4),(2,6),(1,6),(2,8),(1,8)],
 [(1,0),(2,0),(2,1),(2,2),(2,3),(2,4),(2,5),(2,6),(2,7),(2,8)],
 [(2,0),(2,1),(2,2),(2,3),(3,3),(2,5),(2,6),(3,6),(2,8),(4,9)],
 [(3,0),(3,1),(4,1),(3,3),(4,3),(3,5),(3,6),(3,7),(4,7),(4,8)],
 [(4,0),(4,1),(5,1),(5,2),(4,4),(4,5),(5,5),(4,7),(4,8),(4,9)],
 [(5,0),(6,0),(5,2),(6,2),(5,4),(5,5),(6,5),(5,7),(5,8),(6,8)],
 [(6,0),(6,1),(7,1),(7,2),(7,3),(6,5),(6,6),(7,6),(6,8),(6,9)],
 [(7,0),(7,1),(7,2),(7,3),(8,3),(7,5),(7,6),(7,7),(7,8),(8,8)],
 [(8,0),(8,1),(9,1),(9,2),(9,3),(8,5),(9,5),(9,6),(8,8),(8,9)]]

If you wanted to find the distance of the shortest path between (0,0) and (9,9), you would look at:

print(distance[9][9])

If you wanted to find the distance of the shortest path between (0,0) and (3,6), you would look at:

print(distance[3][6])

Now, if you wanted to find the order of points you would take on the shortest path to get to (9,9) or (3,6), you would want another function like this to help you find it:

def find_shortest_path(prev_point_graph, end_point):
    shortest_path = []
    current_point = end_point
    # walk backwards from the end point until we hit the start (whose prev is None)
    while current_point is not None:
        shortest_path.append(current_point)
        current_point = prev_point_graph[current_point[0]][current_point[1]]
    shortest_path.reverse()
    return shortest_path

And you would run:

print(find_shortest_path(prev_point, (9, 9)))

or

print(find_shortest_path(prev_point, (3, 6)))
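If you want a quick sanity check that a path and its distance agree, something like the helper below can be useful. The name path_cost is my own and is not part of the code above. It relies on the same convention as find_shortest_paths: the cost of a step is the value of the point being entered, so the starting point itself is not counted. The small graph and path here are made up for illustration.

```python
def path_cost(graph, path):
    """Sum the cost of entering every point on the path except the first."""
    return sum(graph[row][col] for row, col in path[1:])


small_graph = [
    [1, 9],
    [2, 3],
]
small_path = [(0, 0), (1, 0), (1, 1)]
print(path_cost(small_graph, small_path))  # 2 + 3 = 5
```

On the example above, path_cost(graph, find_shortest_path(prev_point, (9, 9))) should come out equal to distance[9][9].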

I hope this has helped make this key algorithm more palatable. :)
