Every call to a function requires keeping its formal parameters and local variables in memory until the function returns control to the caller.
These variables live in the function’s stack frame (its entry in the call stack), which must be retained until it is no longer required.
For example, consider the following function definition:
public static int factorial(int n) {
    if (n <= 1) {
        return 1;
    }
    int contribution = n;
    int sub = factorial(n - 1);
    int result = contribution * sub;
    return result;
}
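To see why every frame has to stay alive, here is a minimal instrumented sketch of the same function. The extra depth parameter and the print statements are hypothetical additions for illustration only; they are not part of the original example. String.repeat needs Java 11 or later. The indentation of each printed line mirrors how many factorial frames are on the stack at that moment.

```java
public class FactorialTrace {

    // Same logic as factorial(int) above, plus an illustrative depth counter used only for printing.
    public static int factorial(int n, int depth) {
        String indent = "  ".repeat(depth);
        System.out.println(indent + "enter factorial(" + n + ")");
        if (n <= 1) {
            System.out.println(indent + "return 1");
            return 1;
        }
        int contribution = n;
        // This frame cannot be discarded yet: it still needs 'contribution' after the call returns.
        int sub = factorial(n - 1, depth + 1);
        int result = contribution * sub;
        System.out.println(indent + "return " + result);
        return result;
    }

    public static void main(String[] args) {
        factorial(4, 0);
    }
}
```

Running it prints every "enter" line before any "return" line: factorial(4), factorial(3) and factorial(2) are all still waiting when factorial(1) runs.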
Each call to factorial has to wait for the next call to return its value before it can finish and have its stack frame removed.
Let’s say the original call is factorial(4), which calls factorial(3), which calls factorial(2), which calls the terminal case in factorial(1). Green represents the stack frame on the top of the call stack (the active function).
At the time factorial(1) executes, the call stack looks like the following:
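Instead of keeping all these variables alive inside each call, what if we pass the state accumulated so far as parameters to the next call (read that again carefully)? If the recursive call is the very last thing a function does, nothing in the current frame is needed after it returns. A compiler that performs tail-call optimization could then discard the previous stack frames instead of keeping them. Oh my several Gods, that would be amazing!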
P.S. Java (the JVM) still doesn’t perform tail-call optimization… 😞
public static int factorial(int n, int currentState) {
    if (n <= 1) {
        return currentState;
    }
    return factorial(n - 1, currentState * n);
}
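Because the JVM does not eliminate tail calls, a common workaround is to do by hand what a tail-call-optimizing compiler would do: turn the tail call into a loop that overwrites the parameters. The method name factorialLoop below is hypothetical and not from the original. The sketch assumes the same contract as the two-parameter factorial.

```java
public class FactorialLoop {

    // Each iteration plays the role of one tail call factorial(n - 1, currentState * n):
    // instead of pushing a new frame, the parameters are simply overwritten.
    public static int factorialLoop(int n, int currentState) {
        while (n > 1) {
            currentState = currentState * n; // uses the old n, like the recursive argument
            n = n - 1;
        }
        return currentState; // the same value the recursive base case returns
    }

    public static void main(String[] args) {
        System.out.println(factorialLoop(4, 1)); // prints 24
    }
}
```

The loop needs exactly one frame no matter how large n is. The recursive version only gets that behaviour from a runtime that optimizes tail calls.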
The value of currentState in the initial call should be 1 (the identity for multiplication, so it doesn’t change the product). Sample client:
public class Client {
    public static void main(String[] args) {
        int val = factorial(4, 1); // 1 being the initial value of currentState
    }
}
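Here is a self-contained variant of the client. The tail-recursive factorial is placed in the same class, and a print statement is added purely for illustration. It shows the state being carried forward in the arguments rather than in pending stack frames.

```java
public class TailFactorialTrace {

    public static int factorial(int n, int currentState) {
        // Illustrative print: the whole state of the computation is visible in the arguments.
        System.out.println("factorial(" + n + ", " + currentState + ")");
        if (n <= 1) {
            return currentState;
        }
        return factorial(n - 1, currentState * n);
    }

    public static void main(String[] args) {
        int val = factorial(4, 1); // 1 being the initial value of currentState
        System.out.println("result = " + val);
    }
}
```

This prints factorial(4, 1), factorial(3, 4), factorial(2, 12), factorial(1, 24) and finally result = 24. The last call already holds the answer, so no earlier frame has any work left to do.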
Now only the current stack frame needs to be kept in memory. Nothing is left to do after the recursive call returns, so (given tail-call optimization) the previous stack frames for the recursive function can be discarded.
At the time factorial(1, 24) executes, the call stack looks like the following:
You can also create a second proxy function so that you don’t have to pass the second parameter during the initial call, like this:
public static int factorial(int n) {
    return factorial(n, 1);
}
Two methods can have the same name as long as their parameter lists differ, either in the number of parameters or in their types (including the order of those types). This is method overloading.
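Here is a minimal sketch with both factorial overloads in one class. The class name is hypothetical. The compiler picks the overload from the number and types of the arguments at each call site.

```java
public class OverloadedFactorial {

    // Proxy: callers only supply n; the initial state 1 is filled in here.
    public static int factorial(int n) {
        return factorial(n, 1);
    }

    // Tail-recursive worker: the accumulated product travels in currentState.
    public static int factorial(int n, int currentState) {
        if (n <= 1) {
            return currentState;
        }
        return factorial(n - 1, currentState * n);
    }

    public static void main(String[] args) {
        System.out.println(factorial(5));    // resolves to factorial(int), prints 120
        System.out.println(factorial(5, 1)); // resolves to factorial(int, int), prints 120
    }
}
```

One possible design choice, not required by the original, is to make the two-parameter worker private. That way clients cannot accidentally start it with a wrong initial state.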
public static int sumDigits(int n) { // assuming n >= 0
    if (n == 0) {
        return 0;
    }
    return n % 10 + sumDigits(n / 10);
}
public static int sumDigits(int n, int currentState) {
    if (n == 0) {
        return currentState;
    }
    return sumDigits(n / 10, currentState + n % 10);
}
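Here the two versions appear side by side in one class (overloading lets them coexist), which also makes them easy to compare. The class name is hypothetical. For a sum, the initial value of currentState is 0, the identity for addition, rather than the 1 used for the product in factorial.

```java
public class SumDigits {

    // Original version: the addition happens after the recursive call returns.
    public static int sumDigits(int n) { // assuming n >= 0
        if (n == 0) {
            return 0;
        }
        return n % 10 + sumDigits(n / 10);
    }

    // Tail-recursive version: the running sum is carried in currentState.
    public static int sumDigits(int n, int currentState) {
        if (n == 0) {
            return currentState;
        }
        return sumDigits(n / 10, currentState + n % 10);
    }

    public static void main(String[] args) {
        System.out.println(sumDigits(1234));    // prints 10
        System.out.println(sumDigits(1234, 0)); // prints 10
    }
}
```

The tail version visits (1234, 0), (123, 4), (12, 7), (1, 9), (0, 10) and returns 10.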
public static int sumEvenDigits(int n) { // assuming n >= 0
    if (n == 0) {
        return 0;
    }
    if (n % 2 == 0) {
        return n % 10 + sumEvenDigits(n / 10);
    } else {
        return sumEvenDigits(n / 10);
    }
}
public static int sumEvenDigits(int n, int currentState) { // assuming n >= 0
    if (n == 0) {
        return currentState;
    }
    if (n % 2 == 0) {
        return sumEvenDigits(n / 10, currentState + n % 10);
    } else {
        return sumEvenDigits(n / 10, currentState);
    }
}
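The check n % 2 == 0 works because 10 is even, so n is even exactly when its last digit n % 10 is even. The sketch below pairs the tail-recursive version with a proxy that supplies the initial state 0. The class name is hypothetical.

```java
public class SumEvenDigits {

    // Tail-recursive version from above.
    public static int sumEvenDigits(int n, int currentState) { // assuming n >= 0
        if (n == 0) {
            return currentState;
        }
        if (n % 2 == 0) {
            // last digit is even: add it to the running sum
            return sumEvenDigits(n / 10, currentState + n % 10);
        } else {
            // last digit is odd: drop it, state unchanged
            return sumEvenDigits(n / 10, currentState);
        }
    }

    // Proxy so callers don't have to supply the initial state (0).
    public static int sumEvenDigits(int n) {
        return sumEvenDigits(n, 0);
    }

    public static void main(String[] args) {
        System.out.println(sumEvenDigits(123456)); // prints 12 (2 + 4 + 6)
    }
}
```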
public static boolean isPalindrome(String str) {
    if (str == null) {
        return false;
    }
    if (str.length() < 2) {
        return true;
    }
    if (str.charAt(0) != str.charAt(str.length() - 1)) {
        return false;
    }
    return isPalindrome(str.substring(1, str.length() - 1));
}
No change needed: it is already tail-recursive, since the recursive call is the last statement and its result is returned as-is, with nothing else to compute.
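Being tail-recursive doesn’t make it cheap, though. In current JDKs each substring call copies the remaining characters, so the total work grows quadratically with the length. One alternative in the same state-passing spirit is to carry the two indices as parameters instead of building new strings. This is a sketch under that assumption; the lo and hi parameters and the class name are hypothetical.

```java
public class Palindrome {

    // Proxy with the same contract as isPalindrome(String) above (null is not a palindrome).
    public static boolean isPalindrome(String str) {
        if (str == null) {
            return false;
        }
        return isPalindrome(str, 0, str.length() - 1);
    }

    // Tail-recursive worker: lo and hi mark the part of str that is still unchecked.
    private static boolean isPalindrome(String str, int lo, int hi) {
        if (lo >= hi) {
            return true; // zero or one character left
        }
        if (str.charAt(lo) != str.charAt(hi)) {
            return false;
        }
        return isPalindrome(str, lo + 1, hi - 1);
    }

    public static void main(String[] args) {
        System.out.println(isPalindrome("racecar")); // prints true
        System.out.println(isPalindrome("abca"));    // prints false
    }
}
```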
public static String reverse(String str) {
    if (str == null || str.length() < 2) {
        return str;
    }
    char first = str.charAt(0);
    char last = str.charAt(str.length() - 1);
    String remaining = str.substring(1, str.length() - 1);
    return last + reverse(remaining) + first;
}
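The current version, which works on both ends of the string at once, is not convenient for tail optimization. So we’ll make a slight modification first: take one character off the front on each call.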
public static String reverse(String str) {
    if (str == null || str.length() < 2) {
        return str;
    }
    char first = str.charAt(0);
    String remaining = str.substring(1);
    return reverse(remaining) + first;
}
Now it’s ready to tail-optimize. The initial value of constructed should be the empty string.
public static String reverse(String str, String constructed) {
    if (str == null) {
        return null;
    }
    if (str.isEmpty()) {
        return constructed;
    }
    return reverse(str.substring(1), str.charAt(0) + constructed);
}
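As with factorial, a one-parameter proxy can hide the accumulator. This is a minimal sketch with a hypothetical class name. Note that str.charAt(0) + constructed is char plus String, so Java performs string concatenation rather than arithmetic.

```java
public class Reverse {

    // Proxy: the reversed prefix starts out empty.
    public static String reverse(String str) {
        return reverse(str, "");
    }

    // Tail-recursive worker: prepend each character to what has been built so far.
    public static String reverse(String str, String constructed) {
        if (str == null) {
            return null;
        }
        if (str.isEmpty()) {
            return constructed;
        }
        return reverse(str.substring(1), str.charAt(0) + constructed);
    }

    public static void main(String[] args) {
        System.out.println(reverse("hello")); // prints olleh
    }
}
```

The state evolves as ("hello", ""), ("ello", "h"), ("llo", "eh"), ("lo", "leh"), ("o", "lleh"), ("", "olleh").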
public static int gcd(int a, int b) { // assuming a, b >= 0
    if (b == 0) {
        return a;
    }
    return gcd(b, a % b);
}
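Already tail-recursive: the pair (a, b) is the entire state, and nothing remains to be done after the recursive call returns.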
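To watch the state being passed along, here is the same gcd with an illustrative print added. The print and the class name are hypothetical additions, not part of the original.

```java
public class Gcd {

    // Same as gcd above, plus a print of the arguments on each call.
    public static int gcd(int a, int b) { // assuming a, b >= 0
        System.out.println("gcd(" + a + ", " + b + ")");
        if (b == 0) {
            return a;
        }
        return gcd(b, a % b);
    }

    public static void main(String[] args) {
        System.out.println(gcd(48, 18));
    }
}
```

This prints gcd(48, 18), gcd(18, 12), gcd(12, 6), gcd(6, 0) and then 6.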
public static int fib(int n) {
    if (n == 0 || n == 1) {
        return n;
    }
    return fib(n - 1) + fib(n - 2);
}
*This is a tricky one. Please see 1 and 2 for some nice explanations.
// initially called as fib(n, 0, 1)
public static int fib(int n, int a, int b) {
    if (n == 0) {
        // it will only ever be called directly from client,
        // never from another fib call
        return a; // which WILL be 0
    }
    if (n == 1) {
        return b;
    }
    return fib(n - 1, b, a + b);
}
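One way to see why the three-parameter version works: a and b always hold two consecutive Fibonacci numbers. After k calls they are (F(k), F(k+1)) and n has dropped by k, so when n reaches 1, b is F(n) for the original n. The sketch below checks both versions against each other for small n. The class name and the test range are arbitrary choices for illustration.

```java
public class Fibonacci {

    // Original tree-recursive version (exponential time, not tail-recursive).
    public static int fib(int n) {
        if (n == 0 || n == 1) {
            return n;
        }
        return fib(n - 1) + fib(n - 2);
    }

    // Tail-recursive version: (a, b) is a sliding window of two consecutive Fibonacci numbers.
    public static int fib(int n, int a, int b) {
        if (n == 0) {
            return a;
        }
        if (n == 1) {
            return b;
        }
        return fib(n - 1, b, a + b);
    }

    public static void main(String[] args) {
        for (int n = 0; n <= 10; n++) {
            // Both versions should agree; the tail-recursive one starts from (0, 1).
            if (fib(n) != fib(n, 0, 1)) {
                throw new AssertionError("mismatch at n = " + n);
            }
        }
        System.out.println(fib(10, 0, 1)); // prints 55
    }
}
```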