4.4. Fold
The map function gives us a way to individually transform each element of a list. The filter function gives us a way to individually decide whether to keep or throw away each element of a list. But both of those are really just looking at a single element at a time. What if we wanted to somehow combine all the elements of a list? That’s what the fold function is for. It turns out that there are two versions of it, which we’ll study in this section. But to start, let’s look at a related function—not actually in the standard library—that we call combine.
4.4.1. Combine
Once more, let’s write two functions:
(** [sum lst] is the sum of all the elements of [lst]. *)
let rec sum = function
| [] -> 0
| h :: t -> h + sum t
let s = sum [1; 2; 3]
val sum : int list -> int = <fun>
val s : int = 6
(** [concat lst] is the concatenation of all the elements of [lst]. *)
let rec concat = function
| [] -> ""
| h :: t -> h ^ concat t
let c = concat ["a"; "b"; "c"]
val concat : string list -> string = <fun>
val c : string = "abc"
As when we went through similar exercises with map and filter, the functions share a great deal of common structure. The differences here are:
the case for the empty list returns a different initial value, 0 vs. "", and
the case of a non-empty list uses a different operator to combine the head element with the result of the recursive call, + vs. ^.
So can we apply the Abstraction Principle again? Sure! But this time we need to factor out two arguments: one for each of those two differences.
To start, let’s factor out only the initial value:
let rec sum' init = function
| [] -> init
| h :: t -> h + sum' init t
let sum = sum' 0
let rec concat' init = function
| [] -> init
| h :: t -> h ^ concat' init t
let concat = concat' ""
val sum' : int -> int list -> int = <fun>
val sum : int list -> int = <fun>
val concat' : string -> string list -> string = <fun>
val concat : string list -> string = <fun>
Now the only real difference left between sum' and concat' is the operator used to combine the head with the recursive call on the tail. That operator can also become an argument to a unified function we call combine:
let rec combine op init = function
| [] -> init
| h :: t -> op h (combine op init t)
let sum = combine ( + ) 0
let concat = combine ( ^ ) ""
val combine : ('a -> 'b -> 'b) -> 'b -> 'a list -> 'b = <fun>
val sum : int list -> int = <fun>
val concat : string list -> string = <fun>
One way to think of combine would be that:
the [] value in the list gets replaced by init, and
each :: constructor gets replaced by op.
For example, [a; b; c] is just syntactic sugar for a :: (b :: (c :: [])). So if we replace [] with 0 and :: with ( + ), we get a + (b + (c + 0)). And that would be the sum of the list.
Once more, the Abstraction Principle has led us to an amazingly simple and succinct expression of the computation.
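As a small sketch of how far that abstraction reaches, the same combine can be reused with other operators and initial values. The names product and total_length are hypothetical, introduced only for this illustration; the second one also shows that the element type and the result type may differ.

```ocaml
(* Same definition of combine as in the text. *)
let rec combine op init = function
  | [] -> init
  | h :: t -> op h (combine op init t)

(* Replacing [] with 1 and :: with multiplication gives a product. *)
let product = combine ( * ) 1

(* The operator can be any function: here strings are combined into an int. *)
let total_length = combine (fun s n -> String.length s + n) 0

let p = product [2; 3; 4]
let n = total_length ["a"; "bc"; "def"]
```

Here p evaluates to 24 and n to 6.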
4.4.2. Fold Right
The combine function is the idea underlying an actual OCaml library function. To get there, we need to make a couple of changes to the implementation we have so far.
First, let’s rename some of the arguments: we’ll change op to f to emphasize that really we could pass in any function, not just a built-in operator like +. And we’ll change init to acc, which as usual stands for “accumulator”. That yields:
let rec combine f acc = function
| [] -> acc
| h :: t -> f h (combine f acc t)
val combine : ('a -> 'b -> 'b) -> 'b -> 'a list -> 'b = <fun>
Second, let’s make an admittedly less well-motivated change. We’ll swap the implicit list argument to combine with the acc argument (formerly named init):
let rec combine' f lst acc = match lst with
| [] -> acc
| h :: t -> f h (combine' f t acc)
let sum lst = combine' ( + ) lst 0
let concat lst = combine' ( ^ ) lst ""
val combine' : ('a -> 'b -> 'b) -> 'a list -> 'b -> 'b = <fun>
val sum : int list -> int = <fun>
val concat : string list -> string = <fun>
It’s a little less convenient to code the function this way, because we no longer get to take advantage of the function keyword, nor of partial application in defining sum and concat. But there’s no algorithmic change.
What we now have is the actual implementation of the standard library function List.fold_right. All we have left to do is change the function name and add a manual type annotation:
let rec fold_right f lst (acc : 'acc) = match lst with
| [] -> acc
| h :: t -> f h (fold_right f t acc)
val fold_right : ('a -> 'acc -> 'acc) -> 'a list -> 'acc -> 'acc = <fun>
Why is this function called “fold right”? The intuition is that the way it works is to “fold in” elements of the list from the right to the left, combining each new element using the operator. For example, fold_right ( + ) [a; b; c] 0 results in evaluation of the expression a + (b + (c + 0)). The parentheses associate from the right-most subexpression to the left.
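One way to make that parenthesization visible is to fold over strings and build the expression as text instead of evaluating it. The following is only an illustrative sketch; show_fold_right is a hypothetical helper, not a library function.

```ocaml
(* Build a string showing how List.fold_right nests its applications.
   show_fold_right is a hypothetical helper name. *)
let show_fold_right op lst init =
  List.fold_right
    (fun x acc -> "(" ^ x ^ " " ^ op ^ " " ^ acc ^ ")")
    lst init

let shape = show_fold_right "+" ["a"; "b"; "c"] "0"
```

The value of shape is the string "(a + (b + (c + 0)))", with the initial accumulator at the far right.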
Tip
The manual type annotation is not necessary for a correct implementation of the function. Its purpose is to provide a nicer type. Without the annotation, the inferred type of fold_right would be ('a -> 'b -> 'b) -> 'a list -> 'b -> 'b, in which the compiler chooses 'b as the type of the accumulator. By manually annotating that argument with a self-descriptive name, we get the more readable type ('a -> 'acc -> 'acc) -> 'a list -> 'acc -> 'acc.
4.4.3. Tail Recursion and Combine
Neither fold_right nor combine is tail recursive: after the recursive call returns, there is still work to be done in applying the function argument f or op. Let’s go back to combine and rewrite it to be tail recursive. All that requires is to change the cons branch:
let rec combine_tr f acc = function
| [] -> acc
| h :: t -> combine_tr f (f acc h) t (* only real change *)
val combine_tr : ('a -> 'b -> 'a) -> 'a -> 'b list -> 'a = <fun>
(Careful readers will notice that the type of combine_tr is different from the type of combine. We will address that soon.)
Now the function f is applied to the head element h and the accumulator acc before the recursive call is made, thus ensuring there’s no work remaining to be done after the call returns. If that seems a little mysterious, here’s a rewriting of the two functions that might help:
let rec combine f acc = function
| [] -> acc
| h :: t ->
  let acc' = combine f acc t in
  f h acc'
let rec combine_tr f acc = function
| [] -> acc
| h :: t ->
  let acc' = f acc h in
  combine_tr f acc' t
val combine : ('a -> 'b -> 'b) -> 'b -> 'a list -> 'b = <fun>
val combine_tr : ('a -> 'b -> 'a) -> 'a -> 'b list -> 'a = <fun>
Pay close attention to the definition of acc', the new accumulator, in each of those versions:
In the original version, we procrastinate using the head element h. First, we combine all the remaining tail elements to get acc'. Only then do we use f to fold in the head. So the value passed as the initial value of acc turns out to be the same for every recursive invocation of combine: it’s passed all the way down to where it’s needed, at the right-most element of the list, then used there exactly once.
But in the tail recursive version, we “pre-crastinate” by immediately folding h in with the old accumulator acc. Then we fold that in with all the tail elements. So at each recursive invocation, the value passed as the argument acc can be different.
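To watch the accumulator change from call to call in the tail recursive version, we can thread a log through the fold. This is only a sketch: accumulators and step are hypothetical names, and the accumulator is a pair holding the running sum together with the list of sums seen so far.

```ocaml
(* Same definition of combine_tr as in the text. *)
let rec combine_tr f acc = function
  | [] -> acc
  | h :: t -> combine_tr f (f acc h) t

(* Record each new accumulator value produced while summing. *)
let accumulators lst =
  let step (acc, seen) h =
    let acc' = acc + h in
    (acc', acc' :: seen)
  in
  let _, seen = combine_tr step (0, []) lst in
  List.rev seen

let trace = accumulators [1; 2; 3]
```

The result is [1; 3; 6]: starting from 0, each recursive invocation receives a different accumulator.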
The tail recursive version of combine works just fine for summation (and concatenation, which we elide):
let sum = combine_tr ( + ) 0
let s = sum [1; 2; 3]
val sum : int list -> int = <fun>
val s : int = 6
But something possibly surprising happens with subtraction:
let sub = combine ( - ) 0
let s = sub [3; 2; 1]
let sub_tr = combine_tr ( - ) 0
let s' = sub_tr [3; 2; 1]
val sub : int list -> int = <fun>
val s : int = 2
val sub_tr : int list -> int = <fun>
val s' : int = -6
The two results are different!
With combine we compute 3 - (2 - (1 - 0)). First we fold in 1, then 2, then 3. We are processing the list from right to left, putting the initial accumulator at the far right.
But with combine_tr we compute (((0 - 3) - 2) - 1). We are processing the list from left to right, putting the initial accumulator at the far left.
With addition it didn’t matter in which order we processed the list, because addition is associative and commutative. But subtraction is neither, so the two directions result in different answers.
Actually this shouldn’t be too surprising if we think back to when we made map tail recursive. Then, we discovered that tail recursion can cause us to process the list in reverse order from the non-tail recursive version of the same function. That’s what happened here.
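The relationship between the two directions can be checked with a short experiment. The sketch below makes two observations: string concatenation, which is associative with "" as its identity, gives the same answer both ways; and reversing the list while flipping the operator’s arguments makes combine reproduce what combine_tr computes. The names c, c_tr, s_tr, and s_flipped are introduced only for this illustration.

```ocaml
let rec combine f acc = function
  | [] -> acc
  | h :: t -> f h (combine f acc t)

let rec combine_tr f acc = function
  | [] -> acc
  | h :: t -> combine_tr f (f acc h) t

(* Concatenation: both directions agree. *)
let c = combine ( ^ ) "" ["a"; "b"; "c"]
let c_tr = combine_tr ( ^ ) "" ["a"; "b"; "c"]

(* Subtraction: reverse the list and flip the operator's arguments
   to recover the left-to-right result from combine. *)
let s_tr = combine_tr ( - ) 0 [3; 2; 1]
let s_flipped = combine (fun h acc -> acc - h) 0 (List.rev [3; 2; 1])
```

Both c and c_tr are "abc", and both s_tr and s_flipped are -6.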
4.4.4. Fold Left
Our combine_tr function is also in the standard library under the name List.fold_left:
let rec fold_left f (acc : 'acc) = function
| [] -> acc
| h :: t -> fold_left f (f acc h) t
let sum = fold_left ( + ) 0
let concat = fold_left ( ^ ) ""
val fold_left : ('acc -> 'a -> 'acc) -> 'acc -> 'a list -> 'acc = <fun>
val sum : int list -> int = <fun>
val concat : string list -> string = <fun>
We have once more succeeded in applying the Abstraction Principle.
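A computation that is naturally left to right makes a good test case for List.fold_left. The sketch below turns a list of decimal digits into the number they spell; of_digits is a hypothetical name, and the code assumes every element is between 0 and 9.

```ocaml
(* Fold in digits from the left: each step shifts the accumulator
   one decimal place and adds the next digit. *)
let of_digits ds =
  List.fold_left (fun acc d -> 10 * acc + d) 0 ds

let n = of_digits [1; 2; 3]
```

Here n is 123, computed as ((0 * 10 + 1) * 10 + 2) * 10 + 3, with the initial accumulator at the far left.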
4.4.5. Fold Left vs. Fold Right
Let’s review the differences between fold_right and fold_left:
They combine list elements in opposite orders, as indicated by their names. Function fold_right combines from the right to the left, whereas fold_left proceeds from the left to the right.
Function fold_left is tail recursive whereas fold_right is not.
The types of the functions are different. In fold_X the accumulator argument goes to the X of the list argument. That is a choice made by the standard library rather than a necessary implementation difference.
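The differences listed above can be observed in a short experiment. This is a sketch only: it assumes OCaml 4.06 or later for List.init, and the names copied, reversed, big, and count are introduced just for illustration.

```ocaml
(* Opposite orders and argument positions: consing with fold_right
   rebuilds the list, whereas consing with fold_left reverses it. *)
let copied = List.fold_right (fun x acc -> x :: acc) [1; 2; 3] []
let reversed = List.fold_left (fun acc x -> x :: acc) [] [1; 2; 3]

(* Tail recursion: fold_left walks a long list in constant stack space. *)
let big = List.init 1_000_000 (fun i -> i)
let count = List.fold_left (fun acc _ -> acc + 1) 0 big
```

Whether List.fold_right would also succeed on a list as long as big depends on the OCaml version and the stack limits in effect, because its stack usage grows with the length of the list.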
If you find it hard to keep track of the argument orders, the ListLabels module in the standard library can help. It uses labeled arguments to give names to the combining operator (which it calls f) and the initial accumulator value (which it calls init). Internally, the implementation is actually identical to the List module.
ListLabels.fold_left ~f:(fun x y -> x - y) ~init:0 [1; 2; 3];;
- : int = -6
ListLabels.fold_right ~f:(fun y x -> x - y) ~init:0 [1; 2; 3];;
- : int = -6
Notice how in the two applications of fold above, we are able to write the arguments in a uniform order thanks to their labels. However, we still have to be careful about which argument to the combining operator is the list element vs. the accumulator value.
4.4.6. A Digression on Labeled Arguments and Fold
It’s possible to write our own version of the fold functions that would label the arguments to the combining operator, so we don’t even have to remember their order:
let rec fold_left ~op:(f: acc:'a -> elt:'b -> 'a) ~init:acc lst =
match lst with
| [] -> acc
| h :: t -> fold_left ~op:f ~init:(f ~acc:acc ~elt:h) t
let rec fold_right ~op:(f: elt:'a -> acc:'b -> 'b) lst ~init:acc =
match lst with
| [] -> acc
| h :: t -> f ~elt:h ~acc:(fold_right ~op:f t ~init:acc)
val fold_left : op:(acc:'a -> elt:'b -> 'a) -> init:'a -> 'b list -> 'a =
<fun>
val fold_right : op:(elt:'a -> acc:'b -> 'b) -> 'a list -> init:'b -> 'b =
<fun>
But those functions aren’t as useful as they might seem:
let s = fold_left ~op:( + ) ~init:0 [1;2;3]
File "[16]", line 1, characters 22-27:
1 | let s = fold_left ~op:( + ) ~init:0 [1;2;3]
                          ^^^^^
Error: This expression has type int -> int -> int
but an expression was expected of type acc:'a -> elt:'b -> 'a
The problem is that the built-in + operator doesn’t have labeled arguments, so we can’t pass it in as the combining operator to our labeled functions. We’d have to define our own labeled version of it:
let add ~acc ~elt = acc + elt
let s = fold_left ~op:add ~init:0 [1; 2; 3]
But now we have to remember that the ~acc parameter to add will become the left-hand argument to ( + ). That’s not really much of an improvement over what we had to remember to begin with.
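A variation on the same idea is to pass a labeled anonymous function instead of defining a named helper. The sketch below repeats the labeled fold from above under the hypothetical name fold_left_labeled, to avoid shadowing, and uses subtraction so that the roles of the two arguments are visible at the call site. It does not remove the underlying problem: the operator still has to be wrapped by hand.

```ocaml
(* Same labeled fold as in the text, renamed for this illustration. *)
let rec fold_left_labeled ~op:(f : acc:'a -> elt:'b -> 'a) ~init:acc lst =
  match lst with
  | [] -> acc
  | h :: t -> fold_left_labeled ~op:f ~init:(f ~acc:acc ~elt:h) t

(* The labels show which argument is the accumulator. *)
let s = fold_left_labeled ~op:(fun ~acc ~elt -> acc - elt) ~init:0 [1; 2; 3]
```

Here s is -6, the same as ((0 - 1) - 2) - 3.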
4.4.7. Using Fold to Implement Other Functions
Folding is so powerful that we can write many other list functions in terms of fold_left or fold_right. For example,
let length lst =
List.fold_left (fun acc _ -> acc + 1) 0 lst
let rev lst =
List.fold_left (fun acc x -> x :: acc) [] lst
let map f lst =
List.fold_right (fun x acc -> f x :: acc) lst []
let filter f lst =
List.fold_right (fun x acc -> if f x then x :: acc else acc) lst []
val length : 'a list -> int = <fun>
val rev : 'a list -> 'a list = <fun>
val map : ('a -> 'b) -> 'a list -> 'b list = <fun>
val filter : ('a -> bool) -> 'a list -> 'a list = <fun>
At this point it becomes debatable whether it’s better to express the computations above using folding or using the ways we have already seen. Even for an experienced functional programmer, understanding what a fold does can take longer than reading the naive recursive implementation. If you peruse the source code of the standard library, you’ll see that none of the List module is internally implemented in terms of folding, which is perhaps one comment on the readability of fold. On the other hand, using fold ensures that the programmer doesn’t accidentally program the recursive traversal incorrectly. And for a data structure that’s more complicated than lists, that robustness might be a win.
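In the same spirit as length, rev, map, and filter, here is a sketch of three more functions written with folds. They are meant to behave like List.append, List.flatten, and List.exists, but they are illustrations only and make no claim about how the library implements those functions.

```ocaml
(* Cons each element of l1 onto l2, starting from the right. *)
let append l1 l2 =
  List.fold_right (fun x acc -> x :: acc) l1 l2

(* Append every inner list onto the flattened remainder. *)
let flatten lsts =
  List.fold_right append lsts []

(* True if p holds for at least one element. *)
let exists p lst =
  List.fold_left (fun acc x -> acc || p x) false lst
```

For example, append [1; 2] [3] is [1; 2; 3], flatten [[1]; [2; 3]; []] is [1; 2; 3], and exists (fun x -> x > 2) [1; 2; 3] is true.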
4.4.8. Fold vs. Recursive vs. Library
We’ve now seen three different ways of writing functions that manipulate lists:
directly as a recursive function that pattern matches against the empty list and against cons,
using fold functions, and
using other library functions.
Let’s try using each of those ways to solve a problem, so that we can appreciate them better.
Consider writing a function lst_and: bool list -> bool, such that lst_and [a1; ...; an] returns whether all elements of the list are true. That is, it evaluates the same as a1 && a2 && ... && an. When applied to an empty list, it evaluates to true.
Here are three possible ways of writing such a function. We give each way a slightly different function name for clarity.
let rec lst_and_rec = function
| [] -> true
| h :: t -> h && lst_and_rec t
let lst_and_fold =
List.fold_left (fun acc elt -> acc && elt) true
let lst_and_lib =
List.for_all (fun x -> x)
val lst_and_rec : bool list -> bool = <fun>
val lst_and_fold : bool list -> bool = <fun>
val lst_and_lib : bool list -> bool = <fun>
The worst-case running time of all three functions is linear in the length of the list. But:
The first function, lst_and_rec, has the advantage that it need not process the entire list. It will immediately return false the first time it discovers a false element in the list.
The second function, lst_and_fold, will always process every element of the list.
As for the third function, lst_and_lib, according to the documentation of List.for_all, it returns (p a1) && (p a2) && ... && (p an). So like lst_and_rec it need not process every element.
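The difference in how much of the list is visited can be observed by counting calls. The sketch below instruments the fold-based and library-based versions with counters. The names fold_calls, lib_calls, and the _counted functions are hypothetical, added only for this experiment.

```ocaml
let fold_calls = ref 0
let lib_calls = ref 0

(* Count how many times the combining function runs. *)
let lst_and_fold_counted lst =
  List.fold_left (fun acc elt -> incr fold_calls; acc && elt) true lst

(* Count how many times the predicate runs. *)
let lst_and_lib_counted lst =
  List.for_all (fun x -> incr lib_calls; x) lst

let r1 = lst_and_fold_counted [false; true; true]
let r2 = lst_and_lib_counted [false; true; true]
```

Both results are false, but afterwards fold_calls holds 3 while lib_calls holds 1: the fold visits every element, whereas List.for_all stops at the first false.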