
title: Generalized Dijkstra in Haskell
date: 2024-12-20
draft: true

This year's Advent of Code has lots of 2D grids, and makes you traverse them to find paths of various kinds. At some point I had to implement Dijkstra's algorithm, in Haskell. In trying to make my implementation reusable for the following days, I realized that Dijkstra's algorithm is actually way more general than I remembered (or was taught)! In short, weights don't have to be real-valued!

In this post, I describe a general interface for the algorithm, such that we can implement it exactly once and use it to compute many different things.

This article is a literate Haskell file, so feel free to download it and try it for yourself! As such, let's get a few imports and language extensions out of the way:

<details>
  <summary>Haskell Bookkeeping</summary>
{-# LANGUAGE GHC2021 #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE DeriveAnyClass #-}

import Control.Monad (when, foldM)
import Control.Monad.ST (ST, runST)
import Data.Ix (Ix, inRange)
import Data.Array (Array, (!), listArray)
import Data.Array qualified as Array (bounds)
import Data.Array.MArray (newArray, freeze, readArray, writeArray)
import Data.Array.ST (STArray)
import Data.Function ((&))
import Data.Functor ((<&>))
import Data.Kind (Type)
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Semigroup (Sum(Sum))
</details>

A primer on Dijkstra's algorithm

Let's recall the problem statement that Dijkstra's algorithm solves, and the pseudo-code of the algorithm. If you know that already, you can skip to the next section.

The shortest path problem

It's been a while since I had to use any formalism to talk about graphs proper, so I will be using the notation from the Cormen book, which I just looked up for a refresher.

Consider a weighted directed graph G = (V, E, w).

  • V denotes the set of vertices.
  • E \subseteq V \times V denotes the set of edges.
  • Every edge e \in E has an associated positive weight w(e) \in \mathbb{R}, w(e) > 0.

We call a path a sequence of vertices such that there is an edge between every two consecutive vertices of the sequence. If we write \text{paths}(a, b) for the set of paths from a to b in G, this means p = \langle v_0, \dots, v_k \rangle \in \text{paths}(a, b) if v_0 = a, v_k = b and \forall 0 \leq i < k, (v_i, v_{i + 1}) \in E.

We can define the weight of a path as the sum of the weights of its constituent edges.

w(p) = \sum_{i = 1}^k{w(v_{i - 1}, v_i)}

Shortest-paths problems ask questions along the lines of:

  • What is the minimum weight from a to b in G?
  • What is one path from a to b with minimum weight?
  • Can we find all such paths from a to b?

If you interpret the weight as a physical distance, this amounts to finding the shortest trip from one vertex to the other.

Dijkstra's algorithm is a famous technique for solving the single-source shortest-paths problem: finding a shortest path from a given source vertex s to every other vertex v \in V. It's essentially a generalization of breadth-first search to (positively) weighted graphs. And it's pretty fast!

Dijkstra's algorithm

TODO: high-level overview of the algorithm
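As a reference point, here is a minimal sketch of the textbook algorithm on plain integer weights, before any generalization. It assumes a graph given as a Data.Map adjacency list with non-negative Int weights. `Graph` and `shortestDists` are hypothetical names used only for this illustration. A Data.Set of (distance, vertex) pairs plays the role of the priority queue. When a shorter distance is found, the old entry is left in the queue and skipped once popped.

```haskell
import Data.List (foldl')
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set

-- Adjacency list: each vertex maps to its outgoing (neighbour, weight) edges.
-- Weights are assumed to be non-negative.
type Graph = Map Int [(Int, Int)]

-- Shortest distance from the source to every reachable vertex.
shortestDists :: Graph -> Int -> Map Int Int
shortestDists g source = go (Set.singleton (0, source)) (Map.singleton source 0)
  where
    -- The set of (distance, vertex) pairs acts as the priority queue.
    go queue dists = case Set.minView queue of
      Nothing -> dists
      Just ((d, u), queue')
        -- Stale entry: a shorter distance to u was recorded after this one was queued.
        | d > Map.findWithDefault maxBound u dists -> go queue' dists
        | otherwise ->
            let relax (q, ds) (v, w)
                  | d + w < Map.findWithDefault maxBound v ds =
                      (Set.insert (d + w, v) q, Map.insert v (d + w) ds)
                  | otherwise = (q, ds)
                (queue'', dists') = foldl' relax (queue', dists) (Map.findWithDefault [] u g)
            in go queue'' dists'
```

For example, take the directed edges 1→2 (7), 1→3 (9), 1→6 (14), 2→3 (10), 2→4 (15), 3→4 (11), 3→6 (2), 4→5 (6) and 6→5 (9). Starting from vertex 1, the distances to vertices 1 through 6 come out as 0, 7, 9, 20, 20 and 11.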


Taking a step back, generalizing

One thing to notice in the problem statement from earlier is that weights have very little to do with real numbers. In fact, they don't have to be scalars at all!

If we denote W the set of weights, the algorithm merely requires:

  • An equivalence relation (\cdot \approx \cdot) \subseteq W \times W on weights.
    It doesn't have to be definitional equality!
  • A total order on weights, that is: (\cdot\leq\cdot) \subseteq W \times W such that it is reflexive, transitive, total (any two weights are comparable) and anti-symmetric up to \approx (x \leq y and y \leq x imply x \approx y). This order should be compatible with \approx, i.e. equivalence preserves order.
  • A way to add weights together (\cdot \oplus \cdot) : W \rightarrow W \rightarrow W, such that:
    • \oplus is associative.
    • \approx is compatible with \oplus.
    • \leq is compatible with \oplus: x \leq y implies x \oplus z \leq y \oplus z and z \oplus x \leq z \oplus y.
    • x \oplus y is an upper bound of both x and y, i.e. "adding costs together can only increase the total cost".
  • A neutral element 0 for \oplus, that should also be a lower bound of W.
  • An absorbing element \infty for \oplus, that should also be an upper bound of W.

If we summarize, it looks like (W/\approx, \oplus, 0) should be a monoid, totally ordered by \leq and with null element \infty. I think this encompasses all the properties stated above, and nothing more, but I haven't looked that deeply into the formalism and what mathematicians usually call these things.
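To make this list a bit more concrete, here is a sketch of the most familiar instance: path lengths extended with an absorbing infinity. `Dist` and `lawsHold` are hypothetical names introduced for this illustration. Checking the laws on a handful of sample weights only spot-checks them; it proves nothing.

```haskell
-- Path lengths, extended with an absorbing "unreachable" element.
-- Constructor order matters: the derived Ord puts Infinite above every Finite.
data Dist = Finite Int | Infinite
  deriving (Eq, Ord, Show)

-- Adding lengths; Infinite absorbs everything.
instance Semigroup Dist where
  Finite a <> Finite b = Finite (a + b)
  _        <> _        = Infinite

-- Finite 0 is the neutral element.
instance Monoid Dist where
  mempty = Finite 0

-- Spot-check the laws on every pair/triple drawn from a sample of weights.
lawsHold :: [Dist] -> Bool
lawsHold ws = and $
     [ (x <> y) <> z == x <> (y <> z) | x <- ws, y <- ws, z <- ws ]                 -- associativity
  ++ [ x > y || (x <> z <= y <> z && z <> x <= z <> y) | x <- ws, y <- ws, z <- ws ] -- monotonicity
  ++ [ x <= x <> y && y <= x <> y | x <- ws, y <- ws ]                              -- upper bound
  ++ [ mempty <> x == x && x <> mempty == x && mempty <= x | x <- ws ]               -- neutral, lower bound
  ++ [ Infinite <> x == Infinite && x <> Infinite == Infinite && x <= Infinite | x <- ws ] -- absorbing
```

Negative lengths break the picture, as expected. `lawsHold [Finite (-1), Finite 2]` is `False`, because `Finite 0` is no longer a lower bound and adding can decrease the total.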

The restriction that edges must have positive weights can simply be reworded as weights having to be strictly larger than the identity element.

\forall e \in E, w(e) > 0

Now we can freely redefine the weight of a path:


w(p) = \bigoplus_{i = 1}^k{w(v_{i - 1}, v_i)}
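In Haskell terms, this generalized path weight is a fold of the edge weights with the semigroup operation. Here is a minimal sketch in which the edge-weight function w is passed in as an argument; `pathWeight` is a hypothetical helper. It asks for a Monoid rather than a bare Semigroup, so that a single-vertex path, which has no edges, gets the neutral element. Instantiating it at `Sum Double` recovers the real-valued definition.

```haskell
import Data.Monoid (Sum (..))

-- Weight of the path [v0, ..., vk]: combine w v_{i-1} v_i for i = 1..k with (<>).
-- A single-vertex path has no edges and gets the neutral element mempty.
pathWeight :: Monoid m => (v -> v -> m) -> [v] -> m
pathWeight w vs = mconcat (zipWith w vs (drop 1 vs))
```

For instance, with w(a, b) = (a + b) / 2 wrapped in `Sum Double`, the path [1, 2, 3] weighs 1.5 + 2.5 = 4.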

Equipped with this toolkit, we can state the single-source shortest-path problem again: for a given source vertex s \in V, how do we compute the smallest weight achievable on a path from s to any other vertex v \in V?


Abstract Haskell interface and implementation

Now that we've figured out the building blocks that are required for the algorithm to work, let's write this down in Haskell!

Weights

Given the requirements on weights we established earlier, we can try to map each of them to their corresponding Haskell counterpart.

  • Weights should have an equivalence relation: that's Eq.
  • Weights should have a total order: that's Ord.
  • Weights should have an associative addition operation that respects the order: that's Semigroup.

Sadly we're not using Agda so we can't enforce the fact that the order relation must be compatible with the semigroup operation from inside the language. We'll just have to be careful when defining the instances.

So, a Weight should have instances for all three classes above (Eq comes for free, since it is a superclass of Ord in Haskell).

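class (Semigroup a, Ord a) => Weight a where
  -- absorbing element for (<>), and upper bound of all weights
  infty :: a

  -- merge an already known weight (first argument) with an equivalent,
  -- newly discovered one (second argument); by default, keep the known one
  updateWeight :: a -> a -> a
  updateWeight x = const x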

infty is the absorbing element of W. As stated earlier, it must be an upper bound of W.

But what is this updateWeight operation here? It is used to merge equivalent weights. Indeed, during the execution of Dijkstra's algorithm, in the relaxation phase, we may find that the weight of going to v by passing through u is equivalent to the weight we have already computed for v. Because the weight hasn't decreased, we shouldn't update the priority of v in the queue; however, it might still be useful to account for the new paths through u.

That's what this function is for. The only requirement for updateWeight is that the output should be in the same equivalence class as its (equivalent) inputs.

\forall w, w' \in W s.t. w \approx w', \texttt{updateWeight}(w, w') \approx w \approx w'

As a convention, the first argument is the already computed weight, and the second argument is the newly discovered (equivalent) cost along the new path(s) through u.

We then record this merged value as the weight of v. It changes neither the priority of v in the queue nor the order of traversal, but the new information is now accounted for.

The default implementation of updateWeight discards the new weight entirely. This is quite common, say, if we only want to find "a shortest path", and not every one of them.
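As an illustration of a non-default updateWeight, here is a sketch of a weight that counts shortest paths rather than listing them. Two weights are equivalent when their lengths agree, and merging equivalent weights adds their path counts. `Count` and `pathCount` are hypothetical names. The Weight class is restated so that the block compiles on its own.

```haskell
-- The Weight class from above, restated so this block is self-contained.
class (Semigroup a, Ord a) => Weight a where
  infty :: a
  updateWeight :: a -> a -> a
  updateWeight x = const x

-- A path length together with the number of paths of that length.
data Count = Count Int Integer
  deriving Show

-- Equivalence and order only look at the length.
instance Eq Count where
  Count l1 _ == Count l2 _ = l1 == l2

instance Ord Count where
  Count l1 _ `compare` Count l2 _ = compare l1 l2

-- Continuing each of n1 paths with each of n2 continuations gives n1 * n2 paths.
instance Semigroup Count where
  Count l1 n1 <> Count l2 n2 = Count (l1 + l2) (n1 * n2)

instance Weight Count where
  -- stands for "unreachable"; assumes the algorithm never extends it
  infty = Count maxBound 0
  -- equivalent weights have the same length, so their path counts add up
  updateWeight (Count l n) (Count _ m) = Count l (n + m)

pathCount :: Count -> Integer
pathCount (Count _ n) = n
```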

Graphs

Now that we know what weights are, we need to describe what kind of graphs are suitable for our Dijkstra algorithm.

data Dijkstra i c = Dijkstra
    { bounds      :: (i, i)
    , startCost   :: i -> c
    , next        :: i -> c -> [(c, i)]
    }

So, let's expand a bit on the fields of the interface.

  • bounds describes the lower and upper bound of V. This is just an implementation detail: I want to store intermediate weights in a mutable array during the traversal, for efficiency purposes. So I need to know the size of V.

    If you cannot reasonably enumerate all vertices, you can drop the bounds field and use a purely functional, persistent Map instead in the implementation.

  • startCost returns the initial weight we use for a given start vertex. It must always be an identity element for \oplus, and a lower bound of W.

    \forall s \in V, w \in W, \texttt{startCost}(s) \oplus w \approx w \oplus \texttt{startCost}(s) \approx w
    \forall s \in V, w \in W, \texttt{startCost}(s) \leq w

    Concretely, this means that rather than have a single identity 0 \in W, we have one for every vertex. By anti-symmetry of the ordering relation they are all equivalent anyway. This is very useful to store information about the starting vertex in the weight. Say, if we're computing paths, we initially store a 0-length path containing only the starting vertex.

  • And finally, the bread and butter of the graph: a transition function next. For any vertex u and its associated weight w, next u w returns the neighbours of u, with the weight of the edges. As discussed earlier, weight of edges must be strictly larger than 0.

    One may wonder why we take the weight of u as input, and indeed it is weird. Most reasonable transition functions ignore it. But this means you can define funky graphs where the weight of an edge depends on the minimal weight to get there from a specific source. I think it is perfectly fine w.r.t. the assumptions of Dijkstra's algorithm, though knowing exactly what kind of graph this corresponds to is a bit more tedious.

    I show one such example where I rely on this input weight later on.

And here we have it! A description of graphs that can serve as input for the Dijkstra algorithm to solve the single-source shortest-path problem.

Note that this interface is completely agnostic to how we encode our graphs, so long as we can extract a transition function from this underlying representation.
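For example, here is a sketch that builds the record from an explicit adjacency list stored in an array, rather than from a grid. The `fromAdjacency` helper is hypothetical and edge weights are assumed to be plain Ints wrapped in `Sum`. The record type is restated to keep the block self-contained.

```haskell
import Data.Array (Array, listArray, (!))
import qualified Data.Array as Array
import Data.Monoid (Sum (..))

-- The Dijkstra record from above, restated so this block compiles on its own.
data Dijkstra i c = Dijkstra
  { bounds    :: (i, i)
  , startCost :: i -> c
  , next      :: i -> c -> [(c, i)]
  }

-- Vertices are the array indices; entry u lists the outgoing edges of u as (target, weight).
fromAdjacency :: Array Int [(Int, Int)] -> Dijkstra Int (Sum Int)
fromAdjacency adj = Dijkstra
  { bounds    = Array.bounds adj
  , startCost = const (Sum 0)
    -- edges have fixed weights, so the weight accumulated so far is ignored
  , next      = \u _ -> [ (Sum w, v) | (v, w) <- adj ! u ]
  }
```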

Generic Dijkstra implementation

Finally. It is time. We can implement the Dijkstra algorithm. But first we need a priority queue, with the following interface:

type PQueue :: Type -> Type

emptyQ     :: PQueue a
singletonQ :: Ord a => a -> PQueue a
insertQ    :: Ord a => a -> PQueue a -> PQueue a

pattern EmptyQ :: PQueue a
pattern (:<)   :: Ord a => a -> PQueue a -> PQueue a

For simplicity, let's just use a wrapper around Data.Set.

<details>
  <summary><code>PQueue</code> implementation</summary>
newtype PQueue a = PQueue (Set a)

emptyQ = PQueue Set.empty
singletonQ = PQueue . Set.singleton
insertQ x (PQueue s) = PQueue (Set.insert x s)

minView :: PQueue a -> Maybe (a, PQueue a)
minView (PQueue s) =
  case Set.minView s of
    Just (x, s') -> Just (x, PQueue s')
    Nothing      -> Nothing

pattern EmptyQ   <- (minView -> Nothing)
pattern (:<) x q <- (minView -> Just (x, q))

-- tell GHC that these two patterns cover every queue
{-# COMPLETE EmptyQ, (:<) #-}
</details>
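To see the interface in action, here is a self-contained copy of the wrapper plus a hypothetical `drain` helper that pops elements until the queue is empty. Because the patterns go through Set.minView, elements come out in ascending order. Inserting an element that is already present has no effect. The COMPLETE pragma is an addition of this sketch; it tells GHC that the two patterns cover every queue.

```haskell
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}

import Data.Set (Set)
import qualified Data.Set as Set

newtype PQueue a = PQueue (Set a)

emptyQ :: PQueue a
emptyQ = PQueue Set.empty

insertQ :: Ord a => a -> PQueue a -> PQueue a
insertQ x (PQueue s) = PQueue (Set.insert x s)

-- Split off the smallest element, if any.
minView :: PQueue a -> Maybe (a, PQueue a)
minView (PQueue s) = fmap (\(x, s') -> (x, PQueue s')) (Set.minView s)

pattern EmptyQ :: PQueue a
pattern EmptyQ <- (minView -> Nothing)

pattern (:<) :: a -> PQueue a -> PQueue a
pattern x :< q <- (minView -> Just (x, q))

{-# COMPLETE EmptyQ, (:<) #-}

-- Pop every element, smallest first.
drain :: PQueue a -> [a]
drain EmptyQ   = []
drain (x :< q) = x : drain q
```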

I haven't tried the existing implementations available on Hackage yet; I should get around to it at some point. It also looks like I may want a priority search queue, so that I can truly update the priority of a given key.

At last, the implementation for Dijkstra's algorithm:

dijkstra :: (Ix i, Weight c) => Dijkstra i c -> i -> Array i c
dijkstra (Dijkstra{..} :: Dijkstra i c) start = runST do
  costs :: STArray s i c <- newArray bounds infty
  let zero = startCost start
  writeArray costs start zero
  let queue = singletonQ (zero, start)
  aux costs queue
  freeze costs

  where

  aux :: forall s. STArray s i c -> PQueue (c, i) -> ST s ()
  aux _ EmptyQ = pure ()
  aux costs ((uWeight, u) :< queue) = do
    uWeight' <- readArray costs u
    -- because of how our pure PQueue works,
    -- we cannot really "update" the priority of an element in the queue:
    -- instead, we insert it again, with a lower priority.
    -- So if the weight just popped off the queue differs from the one stored
    -- in the array, this entry is stale: u has already been visited.
    if uWeight /= uWeight'
      then aux costs queue -- stale entry: skip it and keep going
      else do
        let
          -- use uWeight' (from the array) rather than uWeight (from the queue):
          -- it includes anything merged in by updateWeight since u was queued
          relaxNeighbour :: PQueue (c, i) -> (c, i) -> ST s (PQueue (c, i))
          relaxNeighbour !q (uvWeight, v) = do
            let !vWeight = uWeight' <> uvWeight
            vWeight' <- readArray costs v
            case vWeight `compare` vWeight' of
              GT -> pure q -- going through u yields a higher weight for v
              EQ -> do -- same weight, we merge them
                writeArray costs v $ updateWeight vWeight' vWeight
                pure q
              LT -> do -- going through u decreases the weight of v
                writeArray costs v vWeight
                pure $ insertQ (vWeight, v) q
        aux costs =<< foldM relaxNeighbour queue (next u uWeight')

Instantiating the interface

The interface for Dijkstra's algorithm is very abstract now. Let's see how to instantiate it to compute useful information!

But first, we define a basic graph that we will traverse in all the following examples.

type Coord = (Int, Int)
type Grid  = Array Coord Char

neighbours :: Coord -> [Coord]
neighbours (x, y) =
    [ (x - 1, y    )
    , (x + 1, y    )
    , (x    , y - 1)
    , (x    , y + 1)
    ]

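emptyCell :: Grid -> Coord -> Bool
emptyCell grid = ('#' /=) . (grid !)

-- A hypothetical sample maze, for illustration only: the first coordinate
-- is the row, the second the column, and '#' marks a wall.
graph :: Grid
graph = listArray ((0, 0), (2, 3)) (concat ["..#.", ".#..", "...."])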

Minimum distance

The simplest example is to try and compute the length of the shortest path between two vertices. We define our cost, in this case an integer for the length.

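newtype MinDist = MinDist Int deriving stock (Eq, Ord, Show) deriving (Semigroup) via (Sum Int) -- lengths add up
instance Weight MinDist where infty = MinDist maxBound -- only marks unreached cells, never extended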

And then we instantiate the interface.

minDist :: Grid -> Dijkstra Coord MinDist
minDist grid = Dijkstra
  { bounds    = Array.bounds grid
  , startCost = const (MinDist 0)
  , next      = step
  }
  where
    -- Moving to any adjacent, in-bounds, non-wall cell costs 1.
    -- The weight accumulated so far is not needed here, so we ignore it.
    step :: Coord -> MinDist -> [(MinDist, Coord)]
    step x _ =
      neighbours x
        & filter (inRange (Array.bounds grid))
        & filter (emptyCell grid)
        & map (MinDist 1,)

getMinDist :: Grid -> Coord -> Coord -> MinDist
getMinDist grid from to = dijkstra (minDist grid) from ! to
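Putting the pieces together, here is a condensed, self-contained sketch that computes minimum distances on a tiny maze. To keep it short, it restates the interface and uses Data.Set.minView directly instead of the PQueue pattern synonyms. It also takes the weight of a popped vertex from the array rather than from the queue entry. The maze `sample` is hypothetical. The block is meant as a sanity check under these assumptions rather than a reference implementation.

```haskell
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE ScopedTypeVariables #-}

import Control.Monad (foldM)
import Control.Monad.ST (ST, runST)
import Data.Array (Array, listArray, (!))
import qualified Data.Array as Array
import Data.Array.MArray (freeze, newArray, readArray, writeArray)
import Data.Array.ST (STArray)
import Data.Ix (Ix, inRange)
import Data.Monoid (Sum (..))
import Data.Set (Set)
import qualified Data.Set as Set

-- The interface from above, restated so that this block compiles on its own.
class (Semigroup a, Ord a) => Weight a where
  infty :: a
  updateWeight :: a -> a -> a
  updateWeight x = const x

data Dijkstra i c = Dijkstra
  { bounds    :: (i, i)
  , startCost :: i -> c
  , next      :: i -> c -> [(c, i)]
  }

dijkstra :: forall i c. (Ix i, Weight c) => Dijkstra i c -> i -> Array i c
dijkstra g start = runST $ do
  costs <- newArray (bounds g) infty :: ST s (STArray s i c)
  writeArray costs start (startCost g start)
  loop costs (Set.singleton (startCost g start, start))
  freeze costs
  where
    loop :: STArray s i c -> Set (c, i) -> ST s ()
    loop costs queue = case Set.minView queue of
      Nothing -> pure ()
      Just ((uWeight, u), rest) -> do
        uWeight' <- readArray costs u
        if uWeight /= uWeight'
          then loop costs rest -- stale entry: u was already settled with a smaller weight
          else loop costs =<< foldM (relax costs uWeight') rest (next g u uWeight')

    -- Relax the edge u -> v, given the settled weight of u.
    relax :: STArray s i c -> c -> Set (c, i) -> (c, i) -> ST s (Set (c, i))
    relax costs uWeight !queue (uvWeight, v) = do
      let vWeight = uWeight <> uvWeight
      vWeight' <- readArray costs v
      case compare vWeight vWeight' of
        GT -> pure queue
        EQ -> queue <$ writeArray costs v (updateWeight vWeight' vWeight)
        LT -> Set.insert (vWeight, v) queue <$ writeArray costs v vWeight

type Coord = (Int, Int)
type Grid = Array Coord Char

-- Lengths add up, like Sum Int; maxBound marks unreached cells and is never extended.
newtype MinDist = MinDist Int
  deriving stock (Eq, Ord, Show)
  deriving (Semigroup) via (Sum Int)

instance Weight MinDist where
  infty = MinDist maxBound

-- Each move to an adjacent, in-bounds, non-wall cell costs 1.
minDist :: Grid -> Dijkstra Coord MinDist
minDist grid = Dijkstra
  { bounds    = Array.bounds grid
  , startCost = const (MinDist 0)
  , next      = \(x, y) _ ->
      [ (MinDist 1, c)
      | c <- [(x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)]
      , inRange (Array.bounds grid) c
      , grid ! c /= '#'
      ]
  }

-- A hypothetical 3x4 maze: the first coordinate is the row, '#' marks a wall.
sample :: Grid
sample = listArray ((0, 0), (2, 3)) (concat ["..#.", ".#..", "...."])
```

From the top-left corner, the cell (0, 3) ends up 7 steps away, since the walls force a detour through the bottom row. The wall at (1, 1) is never entered, so it keeps the weight `infty`.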

Shortest path

Maybe this interface has become too abstract, so let's see how to instantiate it to find the usual shortest path. We introduce a new cost that stores a path along with its length.

data Path i = Path !Int [i]

Given that we only want to find a shortest path, we can put paths with the same length in the same equivalence class, and compare paths only by looking at their length.

instance Eq  (Path i) where
  Path l1 _ == Path l2 _ = l1 == l2

instance Ord (Path i) where 
  Path l1 _ `compare` Path l2 _ = compare l1 l2

Again, because we only want to find a shortest path, if we merge two paths in the same equivalence class, we simply return the first one. Lucky for us, that's the default implementation of updateWeight in Weight (Path i).

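instance Semigroup (Path i) where Path l1 p1 <> Path l2 p2 = Path (l1 + l2) (p2 ++ p1) -- vertices stored most recent first
instance Weight (Path i) where infty = Path maxBound []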

And... that's all we need! Running it on a sample graph gives us a shortest path.
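To see how such a weight records the path itself, here is a self-contained sketch that extends a start weight by hand, one edge at a time, as the relaxation step would. It assumes a Semigroup instance that adds the lengths and stores vertices most recent first. It also assumes a start weight of the form `Path 0 [s]` and edge weights of the form `Path 1 [v]`. `walk` and `route` are hypothetical helpers.

```haskell
-- A path length, with the vertices stored most recent first.
data Path i = Path !Int [i]
  deriving Show

-- Equivalence and order only look at the length.
instance Eq (Path i) where
  Path l1 _ == Path l2 _ = l1 == l2

instance Ord (Path i) where
  Path l1 _ `compare` Path l2 _ = compare l1 l2

-- Lengths add up; the vertices of the right operand are more recent, so they go in front.
instance Semigroup (Path i) where
  Path l1 p1 <> Path l2 p2 = Path (l1 + l2) (p2 ++ p1)

-- Weight of the walk s, v1, ..., vk: start from Path 0 [s], extend by Path 1 [v] per edge.
walk :: i -> [i] -> Path i
walk s vs = foldl (\acc v -> acc <> Path 1 [v]) (Path 0 [s]) vs

-- Recover the vertices in travel order.
route :: Path i -> [i]
route (Path _ p) = reverse p
```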

Shortest paths (plural!)

Now what if we want all the shortest paths? Simple, we define a new cost!

data Paths i = Paths !Int [[i]]

The only difference with Path is that we store a bunch of them. But we compare them in the exact same way.

instance Eq  (Paths i) where
  Paths l1 _ == Paths l2 _ = l1 == l2

instance Ord (Paths i) where
  Paths l1 _ `compare` Paths l2 _ = compare l1 l2

However, when we merge costs we make sure to keep all paths:

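instance Semigroup (Paths i) where Paths l1 xs <> Paths l2 ys = Paths (l1 + l2) [y ++ x | x <- xs, y <- ys]
instance Weight (Paths i) where
  infty = Paths maxBound []
  updateWeight (Paths l xs) (Paths _ ys) = Paths l (xs ++ ys)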

And... that's it!
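As a quick check of the merging behaviour, here is a self-contained sketch of a diamond. Two routes of length 2, one through b and one through c, reach the same vertex d, and updateWeight keeps both. The Semigroup instance is an assumption of this sketch: it continues every stored path with every path of the right operand, most recent vertex first. The `routes` helper is also an assumption of the sketch.

```haskell
-- The Weight class from above, restated so this block is self-contained.
class (Semigroup a, Ord a) => Weight a where
  infty :: a
  updateWeight :: a -> a -> a
  updateWeight x = const x

-- A length together with every path of that length, each stored most recent first.
data Paths i = Paths !Int [[i]]

instance Eq (Paths i) where
  Paths l1 _ == Paths l2 _ = l1 == l2

instance Ord (Paths i) where
  Paths l1 _ `compare` Paths l2 _ = compare l1 l2

-- Extending: every known path is continued by every path of the right operand.
instance Semigroup (Paths i) where
  Paths l1 xs <> Paths l2 ys = Paths (l1 + l2) [y ++ x | x <- xs, y <- ys]

instance Weight (Paths i) where
  infty = Paths maxBound []
  -- equivalent weights: keep the paths from both sides
  updateWeight (Paths l xs) (Paths _ ys) = Paths l (xs ++ ys)

-- All stored paths, in travel order.
routes :: Paths i -> [[i]]
routes (Paths _ ps) = map reverse ps
```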

Closing thoughts

Here we are. I hope this weekend obsession of mine was interesting to someone. It sure was quite surprising to me that an algorithm I was taught a while back could be applied in a more general context quite easily.

Disclaimer: I have done little to no research about whether this generalization has been discussed at large already. I did find a few research papers on routing algorithms over networks that give more algebraic structure to weights. I don't think they match one-to-one with what I describe here, because they seemed to be interested in more general problems. And I haven't found anything targeted at a larger non-scientific audience.

But as always, if you have any feedback, or any additional insight on what's discussed here, please reach out!

Feel free to react on reddit.

<!-- TODO: pre-compile katex to MathML only -->
<!-- <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.16.3/dist/katex.min.css" integrity="sha384-Juol1FqnotbkyZUT5Z7gUPjQ9gzlwCENvUZTpQBAPxtusdwFLRy382PSDx5UUJ4/" crossorigin="anonymous"> -->
<script defer src="https://cdn.jsdelivr.net/npm/katex@0.16.3/dist/katex.min.js" integrity="sha384-97gW6UIJxnlKemYavrqDHSX3SiygeOwIZhwyOKRfSaf0JWKRVj9hLASHgFTzT+0O" crossorigin="anonymous"></script>
<script type="module">
  const macros = {}
  document.querySelectorAll('.math').forEach(elem => {
    katex.render(elem.innerText, elem, {throwOnError: false, macros,
    displayMode: !(elem.classList.contains('inline')), output:'mathml'})
  })
</script>