{
"cells": [
{
"cell_type": "markdown",
"id": "135d81a8",
"metadata": {},
"source": [
"# Fold\n",
"\n",
"The map function gives us a way to individually transform each element of a\n",
"list. The filter function gives us a way to individually decide whether to\n",
"keep or throw away each element of a list. But both of those are really just\n",
"looking at a single element at a time. What if we wanted to somehow combine all\n",
"the elements of a list? That's what the *fold* function is for. It turns out\n",
"that there are two versions of it, which we'll study in this section. But to\n",
"start, let's look at a related function—not actually in the standard\n",
"library—that we call *combine*.\n",
"\n",
"## Combine\n",
"\n",
"{{ video_embed | replace(\"%%VID%%\", \"uYJVwW2BFPg\")}}\n",
"\n",
"Once more, let's write two functions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "25b5ebf5",
"metadata": {},
"outputs": [],
"source": [
"(** [sum lst] is the sum of all the elements of [lst]. *)\n",
"let rec sum = function\n",
" | [] -> 0\n",
" | h :: t -> h + sum t\n",
"\n",
"let s = sum [1; 2; 3]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eeb35eb5",
"metadata": {},
"outputs": [],
"source": [
"(** [concat lst] is the concatenation of all the elements of [lst]. *)\n",
"let rec concat = function\n",
" | [] -> \"\"\n",
" | h :: t -> h ^ concat t\n",
"\n",
"let c = concat [\"a\"; \"b\"; \"c\"]"
]
},
{
"cell_type": "markdown",
"id": "db9e37fd",
"metadata": {},
"source": [
"As when we went through similar exercises with map and filter, the functions\n",
"share a great deal of common structure. The differences here are:\n",
"\n",
"* the case for the empty list returns a different initial value, `0` vs `\"\"`\n",
"\n",
"* the case of a non-empty list uses a different operator to combine the head\n",
" element with the result of the recursive call, `+` vs `^`.\n",
"\n",
"So can we apply the Abstraction Principle again? Sure! But this time we need to\n",
"factor out *two* arguments: one for each of those two differences.\n",
"\n",
"To start, let's factor out only the initial value:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5addc783",
"metadata": {},
"outputs": [],
"source": [
"let rec sum' init = function\n",
" | [] -> init\n",
" | h :: t -> h + sum' init t\n",
"\n",
"let sum = sum' 0\n",
"\n",
"let rec concat' init = function\n",
" | [] -> init\n",
" | h :: t -> h ^ concat' init t\n",
"\n",
"let concat = concat' \"\""
]
},
{
"cell_type": "markdown",
"id": "21f69614",
"metadata": {},
"source": [
"Now the only real difference left between `sum'` and `concat'` is the operator\n",
"used to combine the head with the recursive call on the tail. That operator can\n",
"also become an argument to a unified function we call `combine`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a97b9939",
"metadata": {},
"outputs": [],
"source": [
"let rec combine op init = function\n",
" | [] -> init\n",
" | h :: t -> op h (combine op init t)\n",
"\n",
"let sum = combine ( + ) 0\n",
"let concat = combine ( ^ ) \"\""
]
},
{
"cell_type": "markdown",
"id": "173b7693",
"metadata": {},
"source": [
"One way to think of `combine` would be that:\n",
"\n",
"- the `[]` value in the list gets replaced by `init`, and\n",
"\n",
"- each `::` constructor gets replaced by `op`.\n",
"\n",
"For example, `[a; b; c]` is just syntactic sugar for `a :: (b :: (c :: []))`. So\n",
"if we replace `[]` with `0` and `::` with `(+)`, we get `a + (b + (c + 0))`.\n",
"And that would be the sum of the list.\n",
"\n",
"Once more, the Abstraction Principle has led us to an amazingly simple and\n",
"succinct expression of the computation.\n",
"\n",
"## Fold Right\n",
"\n",
"{{ video_embed | replace(\"%%VID%%\", \"WKKkIGncRn8\")}}\n",
"\n",
"The `combine` function is the idea underlying an actual OCaml library function.\n",
"To get there, we need to make a couple of changes to the implementation we have\n",
"so far.\n",
"\n",
"First, let's rename some of the arguments: we'll change `op` to `f` to emphasize\n",
"that really we could pass in any function, not just a built-in operator like\n",
"`+`. And we'll change `init` to `acc`, which as usual stands for \"accumulator\".\n",
"That yields:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "722b2f55",
"metadata": {},
"outputs": [],
"source": [
"let rec combine f acc = function\n",
" | [] -> acc\n",
" | h :: t -> f h (combine f acc t)"
]
},
{
"cell_type": "markdown",
"id": "74cf31ef",
"metadata": {},
"source": [
"Second, let's make an admittedly less well-motivated change. We'll swap the\n",
"implicit list argument to `combine` with the `init` argument:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d1d3d74",
"metadata": {},
"outputs": [],
"source": [
"let rec combine' f lst acc = match lst with\n",
" | [] -> acc\n",
" | h :: t -> f h (combine' f t acc)\n",
"\n",
"let sum lst = combine' ( + ) lst 0\n",
"let concat lst = combine' ( ^ ) lst \"\""
]
},
{
"cell_type": "markdown",
"id": "a722bf7c",
"metadata": {},
"source": [
"It's a little less convenient to code the function this way, because we no\n",
"longer get to take advantage of the `function` keyword, nor of partial\n",
"application in defining `sum` and `concat`. But there's no algorithmic change.\n",
"\n",
"What we now have is the actual implementation of the standard library function\n",
"`List.fold_right`. All we have left to do is change the function name:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47493819",
"metadata": {},
"outputs": [],
"source": [
"let rec fold_right f lst acc = match lst with\n",
" | [] -> acc\n",
" | h :: t -> f h (fold_right f t acc)"
]
},
{
"cell_type": "markdown",
"id": "53352a95",
"metadata": {},
"source": [
"Why is this function called \"fold right\"? The intuition is that the way it works\n",
"is to \"fold in\" elements of the list from the right to the left, combining each\n",
"new element using the operator. For example, `fold_right ( + ) [a; b; c] 0`\n",
"results in evaluation of the expression `a + (b + (c + 0))`. The parentheses\n",
"associate from the right-most subexpression to the left.\n",
"\n",
"## Tail Recursion and Combine\n",
"\n",
"Neither `fold_right` nor `combine` are tail recursive: after the recursive call\n",
"returns, there is still work to be done in applying the function argument `f` or\n",
"`op`. Let's go back to `combine` and rewrite it to be tail recursive. All that\n",
"requires is to change the cons branch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f62b958f",
"metadata": {},
"outputs": [],
"source": [
"let rec combine_tr f acc = function\n",
" | [] -> acc\n",
" | h :: t -> combine_tr f (f acc h) t (* only real change *)"
]
},
{
"cell_type": "markdown",
"id": "1eb27d0f",
"metadata": {},
"source": [
"(Careful readers will notice that the type of `combine_tr` is different than the\n",
"type of `combine`. We will address that soon.)\n",
"\n",
"Now the function `f` is applied to the head element `h` and the accumulator\n",
"`acc` *before* the recursive call is made, thus ensuring there's no work\n",
"remaining to be done after the call returns. If that seems a little mysterious,\n",
"here's a rewriting of the two functions that might help:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "71153bcc",
"metadata": {},
"outputs": [],
"source": [
"let rec combine f acc = function\n",
" | [] -> acc\n",
" | h :: t ->\n",
" let acc' = combine f acc t in\n",
" f h acc'\n",
"\n",
"let rec combine_tr f acc = function\n",
" | [] -> acc\n",
" | h :: t ->\n",
" let acc' = f acc h in\n",
" combine_tr f acc' t"
]
},
{
"cell_type": "markdown",
"id": "00ee5359",
"metadata": {},
"source": [
"Pay close attention to the definition of `acc'`, the new accumulator, in each\n",
"of those versions:\n",
"\n",
"- In the original version, we procrastinate using the head element `h`. First,\n",
" we combine all the remaining tail elements to get `acc'`. Only then do we use\n",
" `f` to fold in the head. So the value passed as the initial value of `acc`\n",
" turns out to be the same for every recursive invocation of `combine`: it's\n",
" passed all the way down to where it's needed, at the right-most element of the\n",
" list, then used there exactly once.\n",
"\n",
"- But in the tail recursive version, we \"pre-crastinate\" by immediately folding\n",
" `h` in with the old accumulator `acc`. Then we fold that in with all the tail\n",
" elements. So at each recursive invocation, the value passed as the argument\n",
" `acc` can be different.\n",
"\n",
"The tail recursive version of combine works just fine for summation (and\n",
"concatenation, which we elide):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "297cca07",
"metadata": {},
"outputs": [],
"source": [
"let sum = combine_tr ( + ) 0\n",
"let s = sum [1; 2; 3]"
]
},
{
"cell_type": "markdown",
"id": "adbf97a6",
"metadata": {},
"source": [
"But something possibly surprising happens with subtraction:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c25fec7",
"metadata": {},
"outputs": [],
"source": [
"let sub = combine ( - ) 0\n",
"let s = sub [3; 2; 1]\n",
"\n",
"let sub_tr = combine_tr ( - ) 0\n",
"let s' = sub_tr [3; 2; 1]"
]
},
{
"cell_type": "markdown",
"id": "316f723f",
"metadata": {},
"source": [
"The two results are different!\n",
"\n",
"- With `combine` we compute `3 - (2 - (1 - 0))`. First we fold in `1`, then `2`,\n",
" then `3`. We are processing the list from right to left, putting the initial\n",
" accumulator at the far right.\n",
"\n",
"- But with `combine_tr` we compute `(((0 - 3) - 2) - 1)`. We are processing the\n",
" list from left to right, putting the initial accumulator at the far left.\n",
"\n",
"With addition it didn't matter which order we processed the list, because\n",
"addition is associative and commutative. But subtraction is not, so the two\n",
"directions result in different answers.\n",
"\n",
"Actually this shouldn't be too surprising if we think back to when we made `map`\n",
"be tail recursive. Then, we discovered that tail recursion can cause us to\n",
"process the list in reverse order from the non-tail recursive version of the\n",
"same function. That's what happened here.\n",
"\n",
"## Fold Left\n",
"\n",
"Our `combine_tr` function is also in the standard library under the name\n",
"`List.fold_left`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40ae9ab4",
"metadata": {},
"outputs": [],
"source": [
"let rec fold_left f acc = function\n",
" | [] -> acc\n",
" | h :: t -> fold_left f (f acc h) t\n",
"\n",
"let sum = fold_left ( + ) 0\n",
"let concat = fold_left ( ^ ) \"\""
]
},
{
"cell_type": "markdown",
"id": "28e0ae54",
"metadata": {},
"source": [
"We have once more succeeded in applying the Abstraction Principle.\n",
"\n",
"## Fold Left vs. Fold Right\n",
"\n",
"Let's review the differences between `fold_right` and `fold_left`:\n",
"\n",
"- They combine list elements in opposite orders, as indicated by their names.\n",
" Function `fold_right` combines from the right to the left, whereas `fold_left`\n",
" proceeds from the left to the right.\n",
"\n",
"- Function `fold_left` is tail recursive whereas `fold_right` is\n",
" not.\n",
"\n",
"- The types of the functions are different.\n",
"\n",
"Regarding that final point, it can be hard to remember what those types are!\n",
"Luckily we can always ask the toplevel:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "64859ee2",
"metadata": {},
"outputs": [],
"source": [
"List.fold_left;;\n",
"List.fold_right;;"
]
},
{
"cell_type": "markdown",
"id": "99ad9ada",
"metadata": {},
"source": [
"To understand those types, look for the list argument in each one of them. That\n",
"tells you the type of the values in the list. Then look for the type of the\n",
"return value; that tells you the type of the accumulator. From there you can\n",
"work out everything else.\n",
"\n",
"* In `fold_left`, the list argument is of type `'b list`, so the list contains\n",
" values of type `'b`. The return type is `'a`, so the accumulator has type\n",
" `'a`. Knowing that, we can figure out that the second argument is the initial\n",
" value of the accumulator (because it has type `'a`). And we can figure out\n",
" that the first argument, the combining operator, takes as its own first\n",
" argument an accumulator value (because it has type `'a`), as its own second\n",
" argument a list element (because it has type `'b`), and returns a new\n",
" accumulator value.\n",
"\n",
"* In `fold_right`, the list argument is of type `'a list`, so the list contains\n",
" values of type `'a`. The return type is `'b`, so the accumulator has type\n",
" `'b`. Knowing that, we can figure out that the third argument is the initial\n",
" value of the accumulator (because it has type `'b`). And we can figure out\n",
" that the first argument, the combining operator, takes as its own second\n",
" argument an accumulator value (because it has type `'b`), as its own first\n",
" argument a list element (because it has type `'a`), and returns a new\n",
" accumulator value.\n",
"\n",
"```{tip}\n",
"You might wonder why the argument orders are different between the two `fold`\n",
"functions. Good question. Other libraries do in fact use different argument\n",
"orders. One way to remember it for OCaml is that in `fold_X` the accumulator\n",
"argument goes to the `X` of the list argument.\n",
"```\n",
"\n",
"If you find it hard to keep track of all these argument orders, the\n",
"[`ListLabels` module][listlabels] in the standard library can help. It uses\n",
"labeled arguments to give names to the combining operator (which it calls `f`)\n",
"and the initial accumulator value (which it calls `init`). Internally, the\n",
"implementation is actually identical to the `List` module."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c750820",
"metadata": {},
"outputs": [],
"source": [
"ListLabels.fold_left;;\n",
"ListLabels.fold_left ~f:(fun x y -> x - y) ~init:0 [1;2;3];;"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16f5bad2",
"metadata": {},
"outputs": [],
"source": [
"ListLabels.fold_right;;\n",
"ListLabels.fold_right ~f:(fun y x -> x - y) ~init:0 [1;2;3];;"
]
},
{
"cell_type": "markdown",
"id": "725b8c8b",
"metadata": {},
"source": [
"Notice how in the two applications of fold above, we are able to write the\n",
"arguments in a uniform order thanks to their labels. However, we still have to\n",
"be careful about which argument to the combining operator is the list element\n",
"vs. the accumulator value.\n",
"\n",
"[listlabels]: https://ocaml.org/api/ListLabels.html\n",
"\n",
"## A Digression on Labeled Arguments and Fold\n",
"\n",
"It's possible to write our own version of the fold functions that would label\n",
"the arguments to the combining operator, so we don't even have to remember their\n",
"order:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "602cedf5",
"metadata": {},
"outputs": [],
"source": [
"let rec fold_left ~op:(f: acc:'a -> elt:'b -> 'a) ~init:acc lst =\n",
" match lst with\n",
" | [] -> acc\n",
" | h :: t -> fold_left ~op:f ~init:(f ~acc:acc ~elt:h) t\n",
"\n",
"let rec fold_right ~op:(f: elt:'a -> acc:'b -> 'b) lst ~init:acc =\n",
" match lst with\n",
" | [] -> acc\n",
" | h :: t -> f ~elt:h ~acc:(fold_right ~op:f t ~init:acc)"
]
},
{
"cell_type": "markdown",
"id": "c13f297a",
"metadata": {},
"source": [
"But those functions aren't as useful as they might seem:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "23e0011b",
"metadata": {
"tags": [
"raises-exception"
]
},
"outputs": [],
"source": [
"let s = fold_left ~op:( + ) ~init:0 [1;2;3]"
]
},
{
"cell_type": "markdown",
"id": "ce220bf1",
"metadata": {},
"source": [
"The problem is that the built-in `+` operator doesn't have labeled arguments,\n",
"so we can't pass it in as the combining operator to our labeled functions.\n",
"We'd have to define our own labeled version of it:\n",
"\n",
"```\n",
"let add ~acc ~elt = acc + elt\n",
"let s = fold_left ~op:add ~init:0 [1; 2; 3]\n",
"```\n",
"\n",
"But now we have to remember that the `~acc` parameter to `add` will become\n",
"the left-hand argument to `( + )`. That's not really much of an improvement\n",
"over what we had to remember to begin with.\n",
"\n",
"## Using Fold to Implement Other Functions\n",
"\n",
"Folding is so powerful that we can write many other list functions in terms of\n",
"`fold_left` or `fold_right`. For example,"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f814ad33",
"metadata": {},
"outputs": [],
"source": [
"let length lst =\n",
" List.fold_left (fun acc _ -> acc + 1) 0 lst\n",
"\n",
"let rev lst =\n",
" List.fold_left (fun acc x -> x :: acc) [] lst\n",
"\n",
"let map f lst =\n",
" List.fold_right (fun x acc -> f x :: acc) lst []\n",
"\n",
"let filter f lst =\n",
" List.fold_right (fun x acc -> if f x then x :: acc else acc) lst []"
]
},
{
"cell_type": "markdown",
"id": "9ae82764",
"metadata": {},
"source": [
"At this point it begins to become debatable whether it's better to express the\n",
"computations above using folding or using the ways we have already seen. Even\n",
"for an experienced functional programmer, understanding what a fold does can\n",
"take longer than reading the naive recursive implementation. If you peruse the\n",
"[source code of the standard library][list-src], you'll see that none of the\n",
"`List` module internally is implemented in terms of folding, which is perhaps\n",
"one comment on the readability of fold. On the other hand, using fold ensures\n",
"that the programmer doesn't accidentally program the recursive traversal\n",
"incorrectly. And for a data structure that's more complicated than lists, that\n",
"robustness might be a win.\n",
"\n",
"[list-src]: https://github.com/ocaml/ocaml/blob/trunk/stdlib/list.ml\n",
"\n",
"## Fold vs. Recursive vs. Library\n",
"\n",
"We've now seen three different ways for writing functions that manipulate lists:\n",
"\n",
"- directly as a recursive function that pattern matches against the empty list\n",
" and against cons,\n",
"- using `fold` functions, and\n",
"- using other library functions.\n",
"\n",
"Let's try using each of those ways to solve a problem, so that we can appreciate\n",
"them better.\n",
"\n",
"Consider writing a function `lst_and: bool list -> bool`, such that\n",
"`lst_and [a1; ...; an]` returns whether all elements of the list are `true`.\n",
"That is, it evaluates the same as `a1 && a2 && ... && an`. When applied to an\n",
"empty list, it evaluates to `true`.\n",
"\n",
"Here are three possible ways of writing such a function. We give each way a\n",
"slightly different function name for clarity."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24a37bc1",
"metadata": {},
"outputs": [],
"source": [
"let rec lst_and_rec = function\n",
" | [] -> true\n",
" | h :: t -> h && lst_and_rec t\n",
"\n",
"let lst_and_fold =\n",
"\tList.fold_left (fun acc elt -> acc && elt) true\n",
"\n",
"let lst_and_lib =\n",
"\tList.for_all (fun x -> x)"
]
},
{
"cell_type": "markdown",
"id": "0e96cbab",
"metadata": {},
"source": [
"The worst-case running time of all three functions is linear in the length of\n",
"the list. But:\n",
"\n",
"- The first function, `lst_and_rec` has the advantage that it need not process\n",
" the entire list. It will immediately return `false` the first time they\n",
" discover a `false` element in the list.\n",
"\n",
"- The second function, `lst_and_fold`, will always process every element of the\n",
" list.\n",
"\n",
"- As for the third function `lst_and_lib`, according to the documentation of\n",
" `List.for_all`, it returns `(p a1) && (p a2) && ... && (p an)`. So like\n",
" `lst_and_rec` it need not process every element."
]
}
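,
{
"cell_type": "markdown",
"id": "e4f1a0b2",
"metadata": {},
"source": [
"To observe that difference, one rough sketch is to count how many elements\n",
"each approach inspects. The helper `count_inspected` and the counter `seen`\n",
"are hypothetical names introduced only for this illustration. The fold binds\n",
"the element check before using `&&`, so that short-circuiting in `&&` itself\n",
"does not hide the calls:\n",
"\n",
"```ocaml\n",
"(* run [check] on a predicate that counts its calls; return the count *)\n",
"let count_inspected check =\n",
"  let seen = ref 0 in\n",
"  let _ = check (fun x -> incr seen; x) in\n",
"  !seen\n",
"\n",
"let lst = [false; true; true]\n",
"\n",
"(* for_all stops at the first false element *)\n",
"let n_lib = count_inspected (fun p -> List.for_all p lst)\n",
"\n",
"(* fold_left applies the operator to every element *)\n",
"let n_fold =\n",
"  count_inspected (fun p ->\n",
"    List.fold_left (fun acc elt -> let e = p elt in acc && e) true lst)\n",
"\n",
"let () = assert (n_lib = 1 && n_fold = 3)\n",
"```"
]
}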
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"formats": "md:myst",
"text_representation": {
"extension": ".md",
"format_name": "myst",
"format_version": 0.13,
"jupytext_version": "1.10.3"
}
},
"kernelspec": {
"display_name": "OCaml",
"language": "OCaml",
"name": "ocaml-jupyter"
},
"source_map": [
14,
33,
42,
49,
63,
75,
79,
86,
114,
118,
123,
130,
139,
143,
158,
162,
172,
184,
204,
207,
211,
217,
242,
249,
269,
272,
310,
315,
318,
333,
343,
347,
350,
370,
382,
417,
427
]
},
"nbformat": 4,
"nbformat_minor": 5
}