Tail Recursion

We already talked about recursion in the previous topic. Tail recursion is a related concept. To follow it, you need the notion of a head and a tail. If you have a list like (5,4,3,2,1,0), the first element is the head, and the rest is the tail. If I remove the head, the tail becomes a new list. It looks like below.
(4,3,2,1,0)
Now, 4 is the new head, and the rest of the list is the tail. Look at a recursive algorithm, and you will usually find that it can be viewed as working on a list. When I need to think about a problem recursively, I follow these three steps.

  1. Identify the list
  2. Implement the termination condition
  3. Compute the head and recurse with the tail

The last item in the above list is the part where you need to think. The first two steps are much the same for every recursive implementation.
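To make the three steps concrete, here is a minimal sketch that sums a list. The function name sumList and the value names are made up for this illustration; head, tail, and isEmpty are standard methods on a Scala List.

```scala
val numbers = List(5, 4, 3, 2, 1, 0)   // step 1: identify the list
val first = numbers.head                // 5
val rest = numbers.tail                 // List(4, 3, 2, 1, 0)

// Hypothetical example function: sums a list by following the three steps.
def sumList(xs: List[Int]): Int =
  if (xs.isEmpty) 0                     // step 2: termination condition
  else xs.head + sumList(xs.tail)       // step 3: compute with the head, recurse with the tail

val total = sumList(numbers)            // 15
```

Only the third step changes from one problem to the next; here it is an addition.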
Let's take factorial as a recursive example.

                                        
    def factorial(n: Int): Int = {
        if (n <= 0)
            1
        else
            n * factorial(n - 1)
    }
                                    

The above code is a recursive implementation of a factorial function. You can quickly check it against the three steps that I listed at the beginning.

  1. Identify the list - For factorial(5), the list is (5,4,3,2,1,0).
  2. Implement the termination condition - The recursion terminates at zero because that is the last item in the list. Multiplying by zero would wipe out the product, so we return one instead.
  3. Compute the head and recurse with the tail - The compute part multiplies the current head by the result of recursing on the tail, where the next number becomes the new head.

However, there is a catch here. You cannot multiply the first head until the next call returns its result. So, the runtime keeps the first head on the stack and makes a new call. This goes on until you reach the termination condition, with every call waiting for the next one to complete.
The actual multiplication starts only when we reach the termination condition and the chain of calls begins to unwind. Until then, each recursive call occupies a frame on the stack. That is a big problem. If your recursion goes thousands of calls deep, you may run out of stack space, which the JVM reports as a StackOverflowError. That is a significant limitation of recursion.
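The stack limit can be observed with a small experiment. The sketch below uses a sum instead of a factorial so that the result fits in a Long; the names sumTo and overflowsStack are made up for this example, and the depth needed to exhaust the stack depends on the JVM's stack size settings.

```scala
// Hypothetical example: adds the numbers 1..n, leaving an addition pending after each call.
def sumTo(n: Int): Long =
  if (n <= 0) 0L
  else n + sumTo(n - 1)   // the addition waits for the recursive call to return

// Returns true if recursing to the given depth exhausted the call stack.
def overflowsStack(depth: Int): Boolean =
  try { sumTo(depth); false }
  catch { case _: StackOverflowError => true }

val small = sumTo(100)               // 5050
val deep = overflowsStack(1000000)   // likely true with default JVM settings
```

A shallow recursion completes normally, while a very deep one is likely to fail before the first pending addition ever runs.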
You can make the stack visible by throwing an exception. Here is an example that throws an exception at the termination condition.

                                        
    def factorial(n: Int): Int = {
        if (n <= 0)
            throw new Exception("boom!")
        else
            n * factorial(n - 1)
    }

    factorial(5)
    /* Output:
    java.lang.Exception: boom!
            at .factorial(<console>:10)
            at .factorial(<console>:12)
            at .factorial(<console>:12)
            at .factorial(<console>:12)
            at .factorial(<console>:12)
            at .factorial(<console>:12)
            at .<init>(<console>:9)
    */
                                    

You can see that we get an exception, and the stack trace is visible. It shows six calls to factorial, one entry for each recursive call, and every entry consumes some memory. However, if you implement the same logic using a loop, these extra stack frames are not needed.
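For comparison, a loop-based factorial might look like the sketch below. The name factorialLoop is made up for this example. The whole computation runs inside a single call, so the stack does not grow with n.

```scala
// Hypothetical loop-based version: one stack frame, no matter how large n is.
def factorialLoop(n: Int): Int = {
  var result = 1
  var i = n
  while (i > 0) {
    result = result * i   // multiply immediately, nothing is left pending
    i -= 1
  }
  result
}

val answer = factorialLoop(5)   // 120
```

The price is mutable state: result and i are both reassigned on every iteration.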


Is a loop better than recursion?

If you consider memory requirements and performance, the answer is generally yes, because a loop avoids the per-call stack frames. The Scala compiler knows this and tries to optimize recursive calls. However, we have to redesign the recursion in a way that leaves no unfinished operation waiting for the next recursive call. That is where tail recursion comes in. So, we can define it as below.

A tail call is a function call performed as the last action of a function. A function is tail recursive when that last action is a call to itself.

What does it mean?
It means that the recursive call should be the last operation in your function. If we look at the factorial function, it waits for the recursive call and then multiplies the result by n. So the multiplication is the last action. To make it tail recursive, we need to change it so that the recursive call becomes the final action instead of the multiplication.
How can you do that?
It is not that simple, but yes, you can do it with a small trick. In the old logic, the first multiplication happens when we reach the termination condition, where we return one as the starting value. Instead of returning one at the termination, we can pass it in at the beginning and perform the multiplication before the recursive call. To implement this trick, you need two input parameters.
Here is the code.

                                        
    def factorial(n: Int, f: Int): Int = {
        if (n <= 0)
            f
        else
            factorial(n - 1, n * f)
    }
    // f must be one on the initial call for this function to work.
    factorial(5, 1)
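
To see how the second parameter carries the running product, the sketch below records every pair of arguments that the two-parameter factorial above would pass along. The helper factorialTrace and its seen parameter are made up for this illustration.

```scala
// Hypothetical helper: collects each (n, f) pair visited by the tail-recursive factorial.
def factorialTrace(n: Int, f: Int, seen: List[(Int, Int)] = Nil): List[(Int, Int)] =
  if (n <= 0) ((n, f) :: seen).reverse
  else factorialTrace(n - 1, n * f, (n, f) :: seen)

val steps = factorialTrace(5, 1)
// steps == List((5,1), (4,5), (3,20), (2,60), (1,120), (0,120))
```

Each multiplication is finished before the next call is made, so the final pair already holds the answer, 120, and nothing is left to do on the way back.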
                                    

The two-parameter factorial above is tail recursive because the recursive call is its last action. When the Scala compiler recognizes a function that calls itself as its last action, it optimizes the function by converting the recursion into a standard loop. We do not see the change in our source, but the compiler does it internally. This optimization removes the performance and memory problem. If you want to prove it, you can throw an exception again and check the stack trace.

                                        
    def factorial(n: Int, f: Int): Int = {
        if (n <= 0)
            throw new Exception("boom!")
        else
            factorial(n - 1, n * f)
    }

    factorial(5, 1)
    /* Output:
    java.lang.Exception: boom!
            at .factorial(<console>:10)
            at .<init>(<console>:9)
    */
                                    

You can see that there is only one factorial call in the stack trace because the compiler converted the tail recursion into a loop. If you think that a factorial function with two input parameters looks odd, you can wrap it in an outer function.
Here is the code.

                                        
    def factorial(i: Int): Int = {
        def tFactorial(n: Int, f: Int): Int = {
            if (n <= 0) f
            else tFactorial(n - 1, n * f)
        }
        tFactorial(i, 1)
    }
                                    

So factorial takes just one parameter and internally calculates the result using the tFactorial function, which is a local tail-recursive function.
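If you want the compiler to confirm that a function really is tail recursive, one option is the @tailrec annotation from the Scala standard library. The sketch below adds it to the same wrapper. With the annotation in place, compilation fails if the recursive call is not in tail position, instead of silently falling back to ordinary recursion.

```scala
import scala.annotation.tailrec

def factorial(i: Int): Int = {
  // The compiler rejects this definition if tFactorial cannot be turned into a loop.
  @tailrec
  def tFactorial(n: Int, f: Int): Int =
    if (n <= 0) f
    else tFactorial(n - 1, n * f)

  tFactorial(i, 1)
}

val result = factorial(5)   // 120
```

Adding @tailrec to the first, non-tail-recursive factorial would produce a compile error, because the multiplication still has to run after the recursive call returns.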
Keep reading for more interesting functional concepts.


By Prashant Pandey -

