
Folding a matrix

12 December 2014  (updated 2016-03-12)

Write a function that takes a matrix and computes the sum of all its elements. Use only fold.

(where type matrix = int list list)

Well, the answer is this:

let sum_matrix = fold (fold (+)) 0;;

but how do we get here? I just guessed it from the types.
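Before deriving it, here is a quick check that the one-liner really sums a matrix. This is a minimal sketch that aliases fold to List.fold_left, as the next section does. The sample matrix is arbitrary.

```ocaml
(* Alias the standard library's left fold. *)
let fold = List.fold_left

(* The inner fold sums one row; the outer fold threads the running total through the rows. *)
let sum_matrix = fold (fold (+)) 0

let () = Printf.printf "%d\n" (sum_matrix [[1; 2; 3]; [4; 5]; [6]]) (* prints 21 *)
```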

By calculation: Eta-reduction

Before we get to the intuition, let's examine how to recover the one-liner from this fill-in-the-blanks template:

(* Alias the proper fold library function. *)
let fold = List.fold_left;;
(* # fold;;
- : ('a -> 'b -> 'a) -> 'a -> 'b list -> 'a = <fun> *)

let sum_matrix mtx =
  let base = 0 in
  (* sum one list and add it to the running total *)
  let fx base lst =
    let gx = fun a x -> a + x in
    let base' = 0 in
    base + fold gx base' lst in
  fold fx base mtx;;

Notice that we can pass the outer fold's accumulator as the starting accumulator of the inner fold, since all we're doing is summing every number: in effect, flattening the matrix and summing one big list of ints.

let sum_matrix mtx =
  let base = 0 in
  (* sum one list *)
  let fx base lst =
    let gx = fun a x -> a + x in
    fold gx base lst in
  fold fx base mtx;;

See that gx function? It's a binary function that just passes its two arguments to (+) and returns the result; in other words, it is (+) eta-expanded.

Even the types are the same:

gx  : int -> int -> int
(+) : int -> int -> int

We can just drop in (+) instead of writing an explicit lambda: wrapping an infix operator in parentheses turns it into an ordinary prefix function.

let sum_matrix mtx =
  let base = 0 in
  (* sum one list *)
  let fx base lst =
    fold (+) base lst in
  fold fx base mtx;;

Similarly, notice that the last two parameters of fx are passed straight through to fold in the same order, so they can be removed (another eta-reduction).

let sum_matrix mtx =
  let base = 0 in
  (* sum one list *)
  let fx = fold (+) in
  fold fx base mtx;;

Inline the lets:

let sum_matrix mtx = fold (fold (+)) 0 mtx;;

And finally,

let sum_matrix = fold (fold (+)) 0;;

Intuitively, fold expects its last input to be a list, so our definition of sum_matrix should be a partially applied fold.
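Each eta-reduction in this section removes a parameter that is only forwarded, so the expanded and reduced versions should agree on every input. The following is a spot check on sample inputs, not a proof. The _expanded and _reduced names are illustrative only.

```ocaml
let fold = List.fold_left

(* gx is (+) eta-expanded: it only forwards its two arguments. *)
let gx = fun a x -> a + x

(* fx before and after dropping the two parameters it passes straight to fold. *)
let fx_expanded base lst = fold (+) base lst
let fx_reduced = fold (+)

(* sum_matrix before and after dropping mtx. *)
let sum_matrix_expanded mtx = fold (fold (+)) 0 mtx
let sum_matrix = fold (fold (+)) 0

let () =
  let mtx = [[1; 2]; [3]; []; [4; 5; 6]] in
  assert (gx 3 4 = (+) 3 4);
  assert (fx_expanded 10 [1; 2; 3] = fx_reduced 10 [1; 2; 3]);
  assert (sum_matrix_expanded mtx = sum_matrix mtx)
```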

By intuition: partial application

Attempt 1

You already know that fold (+) 0 lst sums a list, and since a matrix is a list of lists of ints, there need to be two folds over the structure of the matrix. The types make this visible and give a high-level picture of the program.

fold : ('a -> 'b -> 'a) -> 'a -> 'b list -> 'a = <fun>
fold (+) : int -> int list -> int = <fun>

We're summing a matrix with type int list list, so we seek a solution of type sum_matrix : int list list -> int

If we gave fold the higher order function fold (+), what would the resulting type be? Just substitute into the type parameters:

'a = int
'b = int list

=> (int -> int list -> int) -> int -> int list list -> int

The resulting type of fold (fold (+)) is:

fold (fold (+)) : int -> int list list -> int = <fun>

Passing in the accumulator 0 fills in the first hole and gives us the function we want.

fold (fold (+)) 0 : int list list -> int = <fun>
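The type substitutions above can be handed to the compiler: each annotation below only typechecks if the derived type is right. This is a minimal sketch, and the intermediate names are hypothetical.

```ocaml
let fold = List.fold_left

(* fold (+) sums one list, starting from a given accumulator. *)
let sum_row_from : int -> int list -> int = fold (+)

(* fold (fold (+)) substitutes 'a = int and 'b = int list into the type of fold. *)
let sum_matrix_from : int -> int list list -> int = fold (fold (+))

(* Supplying the accumulator 0 leaves only the matrix argument. *)
let sum_matrix : int list list -> int = fold (fold (+)) 0

let () = assert (sum_matrix_from 100 [[1; 2]; [3]] = 106)
```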

Attempt 2

Alternatively, you can ask yourself a few questions about the structure of the data and the process:

  1. What are we folding over? A list of lists of ints. So at each step of the fold we process a list of ints.
  2. What do we want from this list of ints? The sum.
  3. What do we do to the sum? Add it to the accumulator.

So write

let sum_matrix =
  let f = fun acc lst -> (fold (+) 0 lst) + acc in
  fold f 0;;

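From here, the one-liner follows after a bit of calculation. First rewrite (fold (+) 0 lst) + acc as fold (+) acc lst (the monoid step used in Attempt 3, plus commutativity of +). Then eta-reduce f to fold (+) and substitute it into fold f 0.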
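That calculation hinges on one identity: adding a row's sum to the accumulator gives the same result as starting the row's fold from the accumulator. The block below spot-checks the identity on a few inputs; it is not a proof. The names step_separate and step_threaded are illustrative. For contrast, the same rewrite fails for (-), which is not associative.

```ocaml
let fold = List.fold_left

(* Attempt 2's step: sum the row from 0, then add the outer accumulator. *)
let step_separate acc row = fold (+) 0 row + acc

(* The calculated step: start the row's fold from the outer accumulator. *)
let step_threaded acc row = fold (+) acc row

let () =
  List.iter
    (fun (acc, row) -> assert (step_separate acc row = step_threaded acc row))
    [ (0, [1; 2; 3]); (10, [4; 5]); (7, []); (-3, [2; -8]) ];
  (* Hence Attempt 2 and the one-liner agree. *)
  let mtx = [[1; 2]; [3; 4]; [5]] in
  assert (fold step_separate 0 mtx = fold (fold (+)) 0 mtx)

(* The same rewrite with (-) changes the answer: 10 - (0 - 1 - 2) = 13, but 10 - 1 - 2 = 7. *)
let () =
  assert (10 - fold (-) 0 [1; 2] = 13);
  assert (fold (-) 10 [1; 2] = 7)
```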

Attempt 3

We could instead have reasoned about the process using higher-order functions: the sum of a list of lists is the sum of map sum applied to it.

let sum = fold (+) 0;;
let map = List.map;;
let sum_matrix mtx = sum (map sum mtx);;
(* => *)
let sum_matrix mtx = fold (+) 0 (map (fold (+) 0) mtx);;

(In case you didn't notice, sum = fold (+) 0)

But the question asked for fold only, so we'll either have to throw this out, or eliminate map.

Let's try eliminating map. First, we need some algebraic machinery:

Theorem (fold-map fusion law [1, 2]).

  foldl g  b . map f = foldl g' b
  where g' b x = g b (f x)

Intuitively, this says that we can apply f to each element at the moment the fold consumes it, instead of in a separate map pass, without changing the order of the inputs or the result.

Thus,

fold (+) 0 (map (fold (+) 0) mtx)

<=> { fold-map fusion }
fold f 0 mtx
where f b e = (+) b (foldl (+) 0 e)

<=> { infix }
fold f 0 mtx
where f b e = b + (foldl (+) 0 e)

<=> { 0 is the identity of (+) and (+) is associative, so the inner fold can start at b (see section on Monoids) }
{ i.e. b + (foldl (+) 0 e)
    => b + (0 + e1 + e2 + ... + en)
    => b + e1 + e2 + ... + en
    => foldl (+) b e }

fold f 0 mtx
where f b e = foldl (+) b e

<=> { eta-reduction }
fold f 0 mtx
where f = foldl (+)

<=> { substitution }
fold (fold (+)) 0 mtx

And we have what we want.
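Both the fusion law and its end result can be spot-checked in OCaml. This is a sketch, not a proof. The helper fused is a hypothetical name that builds g' from g and f exactly as in the theorem, and the inputs are arbitrary.

```ocaml
let fold = List.fold_left

(* g' b x = g b (f x): apply f to each element as the fold reaches it. *)
let fused g f b xs = fold (fun acc x -> g acc (f x)) b xs

let sum = fold (+) 0

let mtx = [[1; 2; 3]; [4]; []; [5; 6]]

let () =
  (* The law for a generic g and f: fold g b (map f xs) = fold g' b xs *)
  let g acc s = acc ^ s and f = string_of_int in
  assert (fold g "" (List.map f [1; 2; 3]) = fused g f "" [1; 2; 3]);
  (* The instance from the derivation: sum (map sum mtx) *)
  assert (sum (List.map sum mtx) = fused (+) sum 0 mtx);
  (* which equals the one-liner *)
  assert (fused (+) sum 0 mtx = fold (fold (+)) 0 mtx)
```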

Attempt 4 (bad)

But if we had written this instead, it wouldn't typecheck:

let sum_matrix = fold (fold (+) 0) 0;;

Why? Make the implicit matrix argument of sum_matrix explicit to get this eta-expanded version:

let sum_matrix mtx = fold (fold (+) 0) 0 mtx;;

The outer fold calls its function argument (here, the inner fold) with two arguments: the accumulator and the row to fold over. But one of the inner fold's slots is already filled! Expanding once more makes this obvious:

let sum_matrix mtx =
  let fx lst = fold (+) 0 lst in
  fold fx 0 mtx;;

fx is meant to take two parameters, but it has only one open parameter. Of course, this doesn't typecheck:

# let sum_matrix = fold (fold (+) 0) 0;;
Error: This expression has type int list -> int
       but an expression was expected of type int list -> 'a -> int list
       Type int is not compatible with type 'a -> int list

The function passed to fold always receives two arguments, so the inner fold must still be waiting for two inputs. With partial application, that means leaving exactly two of fold's three parameters unspecified: fold (+) still takes an accumulator and a list, whereas fold (+) 0 takes only a list.
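Annotating the two candidate inner functions makes the arity mismatch explicit. This is a minimal sketch. The names are illustrative, and the rejected definition is kept in a comment so the block compiles.

```ocaml
let fold = List.fold_left

(* Two open parameters, accumulator and row: fits the function slot of fold. *)
let two_open : int -> int list -> int = fold (+)

(* One open parameter: the accumulator slot is already filled with 0. *)
let one_open : int list -> int = fold (+) 0

let sum_matrix = fold two_open 0

(* fold one_open 0 is rejected: int list -> int does not have the shape 'a -> 'b -> 'a.
   To keep the inner sum from 0, the accumulator has to be added back by hand: *)
let sum_matrix' = fold (fun acc row -> acc + one_open row) 0
```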

(aside)

This "program derivation from types" approach is, in my opinion, one of the main mind-bending ways of thinking you learn through FP. I tend to think of programming like Lego, where you join together bricks with different shapes of connectors into bigger and bigger blocks, with the guidance of the type system. Most of the time you can guess your way through! Or cooler yet, use type holes in Emacs for an interactive, Coq-like programming experience, where you just ask the type system for the answer.

(end of aside)

Structure of lists

Besides being pretty, this formulation reveals a pattern, based on the structure of lists, for flattening and summing nested lists: one fold per level of nesting. What if I wanted to sum a list of lists of lists of ints?

let sum_int_3 = fold (fold (fold (+))) 0;;
# sum_int_3 [[[1;1];[1;1]];[[1;1]]];;
- : int = 6

What about fourth-level nesting, like lst : int list list list list?

let sum_int_4 = fold (fold (fold (fold (+)))) 0;;

And so on. [3]
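The pattern is mechanical: each extra level of nesting wraps the step function in one more fold. Here is a sketch with a hypothetical helper, deeper, which is just fold eta-expanded and named to make the levels visible.

```ocaml
let fold = List.fold_left

(* Turn a step function over elements into a step function over lists of them. *)
let deeper step = fun acc xs -> fold step acc xs

let sum_int_2 = fold (deeper (+)) 0
let sum_int_3 = fold (deeper (deeper (+))) 0
let sum_int_4 = fold (deeper (deeper (deeper (+)))) 0

let () =
  assert (sum_int_2 [[1; 1]; [1]] = 3);
  assert (sum_int_3 [[[1; 1]; [1; 1]]; [[1; 1]]] = 6);
  assert (sum_int_4 [[[[1; 2]]; [[3]]]; [[[4]]]] = 10)
```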


  1. Richard Bird, Pearls of Functional Algorithm Design, p. 255.

  2. Graham Hutton, "A tutorial on the universality and expressiveness of fold".

  3. Though I couldn't figure out how to pattern match on parameterized types in OCaml (e.g. distinguishing int list from 'a list; is this even possible?), it's a simple tree-recursive procedure in Scheme. Note that this allows lists of arbitrary and mixed depth, so e.g. ((1 2 3) (((((4)))))) is treated as a proper matrix, though it obviously isn't one.

    (define (sum-lst-n lst)
      (cond [(empty? lst) 0]
            ;; distinguish between list of int and nested lists
            [(list? (car lst))
             (+ (sum-lst-n (car lst))
                (sum-lst-n (cdr lst)))]
            [else (foldl + 0 lst)]))
    
    ;; test
    > (define lst '(((1) (1 2 3)) (((1 2 3) (((5))))))
    > (sum-lst-n lst)
    18

    If Racket had ML-like support for currying, we could implement this using the pattern we derived.
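As for the question in footnote 3: OCaml erases types at runtime, so a function cannot branch on whether it was given an int list or an int list list. One common workaround, sketched here as an assumption rather than the method of this post, is to make the nesting explicit with a variant type. The names nested, Leaf, Node and sum_nested are hypothetical. Like the Scheme version, this accepts arbitrary and mixed depths, and its step function is again passed straight to fold.

```ocaml
let fold = List.fold_left

(* A value is either a single int or a list of further nested values. *)
type nested = Leaf of int | Node of nested list

(* The step function adds ints directly and recurses into sublists with fold. *)
let rec sum_nested acc = function
  | Leaf n -> acc + n
  | Node xs -> fold sum_nested acc xs

(* The list used to test the Scheme version in footnote 3. *)
let example =
  let ints xs = Node (List.map (fun n -> Leaf n) xs) in
  Node [ Node [ ints [1]; ints [1; 2; 3] ];
         Node [ Node [ ints [1; 2; 3]; Node [ Node [ ints [5] ] ] ] ] ]

let () = Printf.printf "%d\n" (sum_nested 0 example) (* prints 18 *)
```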