{
"cells": [
{
"cell_type": "markdown",
"id": "a0cd9c2a",
"metadata": {},
"source": [
"# Memoization\n",
"\n",
"In the previous section, we saw that the `Lazy` module memoizes the results of\n",
"computations, so that no time has to be wasted on recomputing them. Memoization\n",
"is a powerful technique for asymptotically speeding up simple recursive\n",
"algorithms, without having to change the way the algorithm works.\n",
"\n",
"Let's see apply the Abstraction Principle and invent a way to memoize *any*\n",
"function, so that the function only had to be evaluated once on any given input.\n",
"We'll end up using imperative data structures (arrays and hash tables) as part\n",
"of our solution.\n",
"\n",
"## Fibonacci\n",
"\n",
"Let's again consider the problem of computing the nth Fibonacci number.\n",
"The naive recursive implementation takes exponential time, because of the\n",
"recomputation of the same Fibonacci numbers over and over again:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "24275aa5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"val fib : int -> int = \n"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"let rec fib n = if n < 2 then 1 else fib (n - 1) + fib (n - 2)"
]
},
{
"cell_type": "markdown",
"id": "8e6966ef",
"metadata": {},
"source": [
"```{note}\n",
"To be precise, its running time turns out to be $O(\\phi^n)$, where $\\phi$ is the\n",
"golden ratio, $\\frac{1 + \\sqrt{5}}{2}$.\n",
"```\n",
"\n",
"If we record Fibonacci numbers as they are computed, we can avoid this redundant\n",
"work. The idea is that whenever we compute `f n`, we store it in a table indexed\n",
"by `n`. In this case the indexing keys are integers, so we can use implement\n",
"this table using an array:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "42bdce76",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"val fibm : int -> int = \n"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"let fibm n =\n",
" let memo : int option array = Array.make (n + 1) None in\n",
" let rec f_mem n =\n",
" match memo.(n) with\n",
" | Some result -> (* computed already *) result\n",
" | None ->\n",
" let result =\n",
" if n < 2 then 1 else f_mem (n - 1) + f_mem (n - 2)\n",
" in\n",
" (* record in table *)\n",
" memo.(n) <- Some result;\n",
" result\n",
" in\n",
" f_mem n"
]
},
{
"cell_type": "markdown",
"id": "67a9672f",
"metadata": {},
"source": [
"The function `f_mem` defined inside `fibm` contains the original recursive\n",
"algorithm, except before doing that calculation it first checks if the result\n",
"has already been computed and stored in the table in which case it simply\n",
"returns the result.\n",
"\n",
"How do we analyze the running time of this function? The time spent in a single\n",
"call to `f_mem` is $O(1)$ if we exclude the time spent in any recursive calls\n",
"that it happens to make. Now we look for a way to bound the total number of\n",
"recursive calls by finding some measure of the progress that is being made.\n",
"\n",
"A good choice of progress measure, not only here but also for many uses of\n",
"memoization, is the number of nonempty entries in the table (i.e. entries that\n",
"contain `Some n` rather than `None`). Each time `f_mem` makes the two recursive\n",
"calls it also increases the number of nonempty entries by one (filling in a\n",
"formerly empty entry in the table with a new value). Since the table has only\n",
"`n` entries, there can thus only be a total of $O(n)$ calls to `f_mem`, for a\n",
"total running time of $O(n)$ (because we established above that each call takes\n",
"$O(1)$ time). This speedup from memoization thus reduces the running time from\n",
"exponential to linear, a huge change---e.g., for $n=4$ the speedup from\n",
"memoization is more than a factor of a million!\n",
"\n",
"The key to being able to apply memoization is that there are common sub-problems\n",
"which are being solved repeatedly. Thus we are able to use some extra storage to\n",
"save on repeated computation.\n",
"\n",
"Although this code uses imperative constructs (specifically, array update), the\n",
"side effects are not visible outside the function `fibm`. So from a client's\n",
"perspective, `fibm` is functional. There's no need to mention the imperative\n",
"implementation (i.e., the benign side effects) that are used internally.\n",
"\n",
"## Memoization Using Higher-order Functions\n",
"\n",
"Now that we've seen an example of memoizing one function, let's use higher-order\n",
"functions to memoize any function. First, consider the case of memoizing a\n",
"non-recursive function `f`. In that case we simply need to create a hash table\n",
"that stores the corresponding value for each argument that `f` is called with\n",
"(and to memoize multi-argument functions we can use currying and uncurrying to\n",
"convert to a single argument function)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "0c8b32db",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"val memo : ('a -> 'b) -> 'a -> 'b = \n"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"let memo f =\n",
" let h = Hashtbl.create 11 in\n",
" fun x ->\n",
" try Hashtbl.find h x\n",
" with Not_found ->\n",
" let y = f x in\n",
" Hashtbl.add h x y;\n",
" y"
]
},
{
"cell_type": "markdown",
"id": "99bbcfce",
"metadata": {},
"source": [
"For recursive functions, however, the recursive call structure needs to be\n",
"modified. This can be abstracted out independent of the function that is being\n",
"memoized:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4e945ff8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"val memo_rec : (('a -> 'b) -> 'a -> 'b) -> 'a -> 'b = \n"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"let memo_rec f =\n",
" let h = Hashtbl.create 16 in\n",
" let rec g x =\n",
" try Hashtbl.find h x\n",
" with Not_found ->\n",
" let y = f g x in\n",
" Hashtbl.add h x y;\n",
" y\n",
" in\n",
" g"
]
},
{
"cell_type": "markdown",
"id": "dcc61c53",
"metadata": {},
"source": [
"Now we can slightly rewrite the original `fib` function above using this general\n",
"memoization technique:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "fab6ea97",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"val fib_memo : int -> int = \n"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"let fib_memo =\n",
" let rec fib self n =\n",
" if n < 2 then 1 else self (n - 1) + self (n - 2)\n",
" in\n",
" memo_rec fib"
]
},
{
"cell_type": "markdown",
"id": "38be6c90",
"metadata": {},
"source": [
"## Just for Fun: Party Optimization\n",
"\n",
"Suppose we want to throw a party for a company whose org chart is a binary tree.\n",
"Each employee has an associated “fun value” and we want the set of invited\n",
"employees to have a maximum total fun value. However, no employee is fun if his\n",
"superior is invited, so we never invite two employees who are connected in the\n",
"org chart. (The less fun name for this problem is the maximum weight independent\n",
"set in a tree.) For an org chart with $n$ employees, there are $2^{n}$ possible \n",
"invitation lists, so the naive algorithm that compares the fun of every valid \n",
"invitation list takes exponential time.\n",
"\n",
"We can use memoization to turn this into a linear-time algorithm. We start by\n",
"defining a variant type to represent the employees. The int at each node is the\n",
"fun.\n",
"\n",
"```ocaml\n",
"type tree = Empty | Node of int * tree * tree\n",
"```\n",
"\n",
"Now, how can we solve this recursively? One important observation is that in any\n",
"tree, the optimal invitation list that doesn't include the root node will be the\n",
"union of optimal invitation lists for the left and right subtrees. And the\n",
"optimal invitation list that does include the root node will be the union of\n",
"optimal invitation lists for the left and right children that do not include\n",
"their respective root nodes. So it seems useful to have functions that optimize\n",
"the invite lists for the case where the root node is required to be invited, and\n",
"for the case where the root node is excluded. We'll call these two functions\n",
"party_in and party_out. Then the result of party is just the maximum of these\n",
"two functions:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "40404386",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"module Unmemoized :\n",
" sig\n",
" type tree = Empty | Node of int * tree * tree\n",
" val party : tree -> int\n",
" val party_in : tree -> int\n",
" val party_out : tree -> int\n",
" end\n"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"module Unmemoized = struct\n",
" type tree =\n",
" | Empty\n",
" | Node of int * tree * tree\n",
"\n",
" (* Returns optimum fun for t. *)\n",
" let rec party t = max (party_in t) (party_out t)\n",
"\n",
" (* Returns optimum fun for t assuming the root node of t\n",
" * is included. *)\n",
" and party_in t =\n",
" match t with\n",
" | Empty -> 0\n",
" | Node (v, left, right) -> v + party_out left + party_out right\n",
"\n",
" (* Returns optimum fun for t assuming the root node of t\n",
" * is excluded. *)\n",
" and party_out t =\n",
" match t with\n",
" | Empty -> 0\n",
" | Node (v, left, right) -> party left + party right\n",
"end"
]
},
{
"cell_type": "markdown",
"id": "12b36561",
"metadata": {},
"source": [
"This code has exponential running time. But notice that there are only $n$\n",
"possible distinct calls to party. If we change the code to memoize the results\n",
"of these calls, the performance will be linear in $n$. Here is a version that\n",
"memoizes the result of party and also computes the actual invitation lists.\n",
"Notice that this code memoizes results directly in the tree."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "59c7654f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"module Memoized :\n",
" sig\n",
" type tree =\n",
" Empty\n",
" | Node of int * string * tree * tree * (int * string list) option ref\n",
" val party : tree -> int * string list\n",
" val party_in : tree -> int * string list\n",
" val party_out : tree -> int * string list\n",
" end\n"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"module Memoized = struct\n",
" (* This version memoizes the optimal fun value for each tree node. It\n",
" also remembers the best invite list. Each tree node has the name of\n",
" the employee as a string. *)\n",
" type tree =\n",
" | Empty\n",
" | Node of\n",
" int * string * tree * tree * (int * string list) option ref\n",
"\n",
" let rec party t : int * string list =\n",
" match t with\n",
" | Empty -> (0, [])\n",
" | Node (_, name, left, right, memo) -> (\n",
" match !memo with\n",
" | Some result -> result\n",
" | None ->\n",
" let infun, innames = party_in t in\n",
" let outfun, outnames = party_out t in\n",
" let result =\n",
" if infun > outfun then (infun, innames)\n",
" else (outfun, outnames)\n",
" in\n",
" memo := Some result;\n",
" result)\n",
"\n",
" and party_in t =\n",
" match t with\n",
" | Empty -> (0, [])\n",
" | Node (v, name, l, r, _) ->\n",
" let lfun, lnames = party_out l and rfun, rnames = party_out r in\n",
" (v + lfun + rfun, name :: lnames @ rnames)\n",
"\n",
" and party_out t =\n",
" match t with\n",
" | Empty -> (0, [])\n",
" | Node (_, _, l, r, _) ->\n",
" let lfun, lnames = party l and rfun, rnames = party r in\n",
" (lfun + rfun, lnames @ rnames)\n",
"end"
]
},
{
"cell_type": "markdown",
"id": "e2b15492",
"metadata": {},
"source": [
"Why was memoization so effective for solving this problem? As with the Fibonacci\n",
"algorithm, we had the overlapping sub-problems property, in which the naive\n",
"recursive implementation called the function party many times with the same\n",
"arguments. Memoization saves all those calls. Further, the party optimization\n",
"problem has the property of optimal substructure, meaning that the optimal\n",
"answer to a problem is computed from optimal answers to sub-problems. Not all\n",
"optimization problems have this property. The key to using memoization\n",
"effectively for optimization problems is to figure out how to write a recursive\n",
"function that implements the algorithm and has two properties. Sometimes this\n",
"requires thinking carefully.\n",
"\n",
""
]
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"formats": "md:myst",
"text_representation": {
"extension": ".md",
"format_name": "myst",
"format_version": 0.13,
"jupytext_version": "1.10.3"
}
},
"kernelspec": {
"display_name": "OCaml",
"language": "OCaml",
"name": "ocaml-jupyter"
},
"language_info": {
"codemirror_mode": "text/x-ocaml",
"file_extension": ".ml",
"mimetype": "text/x-ocaml",
"name": "OCaml",
"nbconverter_exporter": null,
"pygments_lexer": "OCaml",
"version": "4.14.0"
},
"source_map": [
14,
34,
36,
48,
63,
104,
113,
119,
130,
135,
141,
173,
196,
204,
244
]
},
"nbformat": 4,
"nbformat_minor": 5
}