How Tail Call Optimization Works

Most undergraduate computer science courses teach students about tail call optimization (TCO), and even if you don't have a formal computer science background the concept comes up often enough that you might be familiar with it anyway, especially if you've ever done any functional programming. However, I think the way TCO is usually taught is confusing, because it's taught in the context of recursion. It's taught this way because without TCO many recursive functions can blow up the stack, causing a stack overflow. Teaching TCO in the context of recursion lets you explain why optimizing compilers (or interpreters) can run tail-recursive code efficiently, without causing a stack overflow.

However, the recursive case is actually not the norm for TCO: in fact, if you're writing code in C, C++, or most other languages with an optimizing compiler, TCO is almost certainly being applied all over your programs even if they use no recursion whatsoever. The non-recursive case of TCO is also a lot simpler to understand, and once you understand it you'll see that there's nothing special whatsoever about how TCO is applied to recursive functions.
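
As a quick taste of the non-recursive case, here's a sketch (log_error is a made-up name): the call to puts is the last thing the function does, so at -O2 the compiler is free to end the function with a jump to puts instead of a call.

#include <stdio.h>

/* sketch: the call to puts is in tail position, so the compiler
   can end the function with a jump to puts instead of a call */
int log_error(const char *msg) {
  return puts(msg);  /* tail call: its return value is returned unchanged */
}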

How Function Calls Work

First I'm going to give a refresher about how function calls work on x86. Note that basically every ISA has this same model, including ARM, so nothing I say here is x86 specific except for the registers and x86 mnemonics I'll use.

Your computer has a bunch of registers, and one of them is the program counter (PC), which x86 calls the instruction pointer. Every time the CPU executes an instruction it automatically advances the PC so that it points to the next instruction. For example, suppose we execute a nop instruction in x86, a no-op that's encoded using a single byte. After executing the nop the PC will have advanced by 1, because that's the size of nop and advancing by 1 makes it point to the next instruction. There are a few special instructions that modify the PC in a different way: jmp (i.e. "jump"), call, and ret (i.e. "return"). Typically jmp and call take an immediate operand, which is the offset in bytes to adjust the PC by, and ret takes no operand.

We'll start with jmp, because it's the simpler of these instructions. An instruction like jmp $15 means "set the PC to 15 more than its current value"; in other words, it jumps forward over 15 bytes' worth of instructions (I'm using AT&T syntax here, which is why the 15 has a $ in front: that marks it as an immediate value rather than a memory address). You can also jump backwards by giving a negative operand. Normally this instruction (together with its conditional variants like je and jne) is used to implement basic flow control constructs like if, else, and for, as sketched below.
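
For instance, here's a sketch (still AT&T syntax, with labels standing in for raw byte offsets; the labels are made up for illustration) of how an if/else can be lowered with jumps:

  # sketch: if (x) y = 1; else y = 2; lowered to jumps
  test  %edi,%edi      # set the flags based on x
  je    take_else      # conditional jump: go to the else branch if x == 0
  mov   $1,%eax        # then branch: y = 1
  jmp   done           # unconditional jump over the else branch
take_else:
  mov   $2,%eax        # else branch: y = 2
done: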

The cousin of jmp is call, and it works exactly the same way as jmp except it has the side effect of pushing the PC onto the stack. More precisely, on x86 the PC is adjusted before an instruction executes, so when call pushes the PC onto the stack it's actually pushing the address of the instruction after the call. The PC is pushed onto the stack so the called function can return to the right spot. To return, a function executes ret, which pops the top value off the stack and then jumps to that location. Both call and ret are really convenience instructions: they act as if executing a push plus a jmp, or a pop plus a jmp.
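
To make that equivalence concrete, here's a sketch in AT&T syntax. It's not literally what the CPU does (real call pushes the return address without clobbering any register, and you can't push %rip directly; target and return_here are made-up labels), but the effect is the same:

  # what `call target` effectively does:
  lea   return_here(%rip),%rax   # address of the instruction after the call
  push  %rax                     # push the return address onto the stack
  jmp   target                   # jump into the function
return_here:

  # what `ret` effectively does:
  pop   %rax                     # pop the return address pushed by `call`
  jmp   *%rax                    # jump back to it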

Tail Call Optimization

Suppose we have a function like this:

#include <stdio.h>

void foo(int x) {
  int y = x * 2;
  printf("x = %d, y = %d\n", x, y);
}

If you compile this code with gcc -O1 it will generate assembly for this function in the obvious way. Here's what it looks like:

(gdb) disas foo
Dump of assembler code for function foo:
   0x0000000000000000 <+0>:	sub    $0x8,%rsp
   0x0000000000000004 <+4>:	mov    %edi,%esi
   0x0000000000000006 <+6>:	lea    (%rdi,%rdi,1),%edx
   0x0000000000000009 <+9>:	mov    $0x0,%edi
   0x000000000000000e <+14>:	mov    $0x0,%eax
   0x0000000000000013 <+19>:	call   0x18 <foo+24>
   0x0000000000000018 <+24>:	add    $0x8,%rsp
   0x000000000000001c <+28>:	ret

I'm not going to explain every line of this, but I'll explain the important bits. The first instruction, sub $0x8,%rsp, moves the stack pointer down by 8 bytes; note from the lea that the variable y actually lives in the register %edx, so this adjustment really just keeps the stack 16-byte aligned across the call, as the x86-64 ABI requires. The penultimate instruction, add $0x8,%rsp, moves the stack pointer back to its original value, and the last instruction, ret, returns to the caller in the way described earlier (i.e. by popping the return address from the stack and jumping to it). In general, before executing ret a function always needs to undo any adjustments it made to the stack pointer (%rsp) so that ret will pop the right value. There's also a funny-looking call instruction in this example: the operand is bogus (a relative offset of zero, which is why it appears to target the next instruction) because the compiler only sees the declaration of printf() and doesn't yet know where it actually lives. In the final linking step the operand will be patched so that the call lands in libc.

Now let's look at what is generated when compiling with gcc -O2:

(gdb) disas foo
Dump of assembler code for function foo:
   0x0000000000000000 <+0>:	mov    %edi,%esi
   0x0000000000000002 <+2>:	lea    (%rdi,%rdi,1),%edx
   0x0000000000000005 <+5>:	xor    %eax,%eax
   0x0000000000000007 <+7>:	mov    $0x0,%edi
   0x000000000000000c <+12>:	jmp    0x11

In this version the compiler has dropped the stack adjustment entirely (as before, lea loads the value of y directly into %edx, and with no call left in the function there's no alignment to maintain), but this isn't really relevant to TCO. The interesting bit is the end of the function: where before we had a call, now we have a jmp, and notably there's no ret instruction at all. Once again the operand to jmp is just a dummy value that will be patched at link time.

This is tail call optimization: the compiler transforms a call immediately followed by a ret into a single jmp. The transformation saves one instruction, and more importantly it eliminates the implicit stack push/pop done by call and ret. And your compiler does it all the time, not just for recursive calls: in general TCO can be applied any time the last thing a function does is call another function. So how does it work?
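
Concretely, the transformation at the end of foo looks like this:

  # without TCO: two instructions, one stack push and one stack pop
  call  printf      # pushes the return address, jumps into printf
  ret               # pops the return address into foo's caller, jumps to it

  # with TCO: one instruction, no stack traffic
  jmp   printf      # printf's own `ret` returns directly to foo's caller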

In the optimized version the jmp instruction jumps directly to printf without pushing the PC onto the stack. Imagine this is part of a larger program, and some other function calls foo. When foo is entered, the return address into that caller is at the top of the stack. When foo executes the jmp at the end, it jumps straight into the printf code with that return address still on top of the stack. printf has its own ret that it uses to return. Since foo never pushed its own return address, when printf executes its ret it pops the value pointing back into foo's caller, and execution resumes there exactly as if foo had returned itself.

Recursive Tail Call Optimization

If a function makes a recursive tail call, TCO is applied just as above: a simple compiler can treat recursive TCO the same as the non-recursive case, without any special logic. For fun, let's look at a tail-call version of factorial:

/* Tail-call recursive helper for factorial */
int factorial_accumulate(int n, int accum) {
  return n < 2 ? accum : factorial_accumulate(n - 1, n * accum);
}

int factorial(int n) { return factorial_accumulate(n, 1); }

Note that this is not the naive version of factorial, implemented as n * factorial(n - 1). In the naive version the function makes the recursive call and then still has to multiply the call's return value by n before it can return, which means the recursive call is not in tail position.
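
For contrast, here's that naive version spelled out; the pending multiply is what keeps the recursive call out of tail position:

/* naive version: NOT a tail call, because the multiply by n happens
   after the recursive call returns */
int factorial_naive(int n) {
  return n < 2 ? 1 : n * factorial_naive(n - 1);
}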

Here's what I get when I compile this with gcc -O2:

(gdb) disas factorial
Dump of assembler code for function factorial:
   0x0000000000000040 <+0>:	mov    $0x1,%eax
   0x0000000000000045 <+5>:	cmp    $0x1,%edi
   0x0000000000000048 <+8>:	jle    0x60 <factorial+32>
   0x000000000000004a <+10>:	nopw   0x0(%rax,%rax,1)
   0x0000000000000050 <+16>:	imul   %edi,%eax
   0x0000000000000053 <+19>:	sub    $0x1,%edi
   0x0000000000000056 <+22>:	cmp    $0x1,%edi
   0x0000000000000059 <+25>:	jne    0x50 <factorial+16>
   0x000000000000005b <+27>:	ret
   0x000000000000005c <+28>:	nopl   0x0(%rax)
   0x0000000000000060 <+32>:	ret
End of assembler dump.

(gdb) disas factorial_accumulate
Dump of assembler code for function factorial_accumulate:
   0x0000000000000020 <+0>:	mov    %esi,%eax
   0x0000000000000022 <+2>:	cmp    $0x1,%edi
   0x0000000000000025 <+5>:	jle    0x3b <factorial_accumulate+27>
   0x0000000000000027 <+7>:	nopw   0x0(%rax,%rax,1)
   0x0000000000000030 <+16>:	imul   %edi,%eax
   0x0000000000000033 <+19>:	sub    $0x1,%edi
   0x0000000000000036 <+22>:	cmp    $0x1,%edi
   0x0000000000000039 <+25>:	jne    0x30 <factorial_accumulate+16>
   0x000000000000003b <+27>:	ret
End of assembler dump.

This is actually pretty interesting! For one thing, in the code for factorial the compiler completely inlined the factorial_accumulate logic, so factorial doesn't call factorial_accumulate at all. However, it still generated code for factorial_accumulate in my object file. If you take a minute to read through it you can see how factorial_accumulate works. The first three instructions (the mov, cmp, jle sequence) test for the recursion base case (i.e. n < 2) and jump straight to a ret when it holds. Otherwise the code enters a loop: imul does the multiply, sub decrements n, another cmp checks whether the base case has been reached, and jne jumps back to the imul when it hasn't. When the base case is finally hit, the jne falls through to the ret.
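
In other words, the tail call has become an ordinary loop. Here's a rough C rendering of what gcc -O2 actually produced (the function name here is mine, not the compiler's):

/* roughly what the optimized factorial_accumulate computes: a plain
   loop, with no recursion and no stack growth */
int factorial_accumulate_loop(int n, int accum) {
  while (n >= 2) {   /* the cmp/jne pair in the disassembly */
    accum *= n;      /* imul %edi,%eax */
    n -= 1;          /* sub $0x1,%edi */
  }
  return accum;
}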

The code for factorial works almost exactly the same way, and you'll notice essentially the same instruction sequence. Some of the registers are different because factorial is called with one parameter instead of two. For some reason it has two ret instructions, presumably for alignment reasons (which is also why the nopw and nopl instructions are there; they're no-ops that exist only to pad things out for alignment).