# How to Sort Lists in Scala with Tail Recursion

Every CS101 course is full of sorting lists. When you learn a programming language, one of the first problems you solve is sorting lists. I get it. You might be tired of sorting lists. Here’s why this article will help you:

- If you’re just getting started with Scala, you’ll at least get to sort lists elegantly with proper FP.
- If you’re more versed, you might find some extra tools here that you can use for any problem.
- You’ll learn to apply tail recursion to a real problem, before some use-case at work demands that from you “yesterday”.

This is one of the (dozens of) problems that we solve in the Scala & FP practice course, which in rough translation is “Scala for coding interviews”.

## 1. Introduction

I’m pretty sure you know the problem but I’ll state it anyway: you’re given a list of integers. Run a method to sort it in ascending order, without mutating the original list. In other words, write a method

```
def sortList(list: List[Int]): List[Int]
```

such that the elements in the resulting list come in ascending order.

## 2. Insertion Sort, the Quick’n’Dirty Way

There are a million sorting algorithms for a list. For the purpose of this article — showing tail recursion on a real problem — we’ll use *insertion sort*, which is most easily understood and read in an FP language, especially when you’re starting out.

For insertion sort, we consider a special operation called *insert*, which adds an element into an *already-sorted* list and returns a new sorted list. For example, if we insert the number 2 into the list `[1,3,4]`, we get the list `[1,2,3,4]`. Its logic is as follows:

- Inserting a number into an empty list gives a new list containing just that number. Inserting 2 into `[]` gives the list `[2]`.
- Inserting a number into a list where the number is smaller than the first element (head) of the list means just prepending that number. Example: inserting 1 into `[2,3,4]` gives `[1,2,3,4]`.
- Otherwise, the head of the list remains in its place (first), and we recursively insert the number into the rest of the list.

The code will look like this:

```
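def insertSorted(element: Int, sortedList: List[Int]): List[Int] =
  // empty list, or element belongs in front of the head: prepend it
  if (sortedList.isEmpty || element <= sortedList.head) element :: sortedList
  // otherwise keep the head in place and insert into the rest
  else sortedList.head :: insertSorted(element, sortedList.tail)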
```

As an example, inserting 3 into the list `[1,2,4]` unfolds step by step as follows, in pseudo-Scala:

```
insertSorted(3, [1,2,4]) =
1 :: insertSorted(3, [2,4]) =
1 :: 2 :: insertSorted(3, [4]) =
1 :: 2 :: 3 :: [4] =
[1,2,3,4]
```

With insertion in place, we can describe the logic for sorting as a whole:

- If the list to be sorted is empty or has a single element, then return that same list. Nothing to sort.
- Otherwise, recursively sort the tail of the list and use `insertSorted` to insert the head into that (sorted) list.

The code is a formal version of the above:

```
def insertionSort(list: List[Int]): List[Int] = {
  if (list.isEmpty || list.tail.isEmpty) list
  else insertSorted(list.head, insertionSort(list.tail))
}
```

This leads us to the complete solution:

```
def insertionSort(list: List[Int]): List[Int] = {
  def insertSorted(element: Int, sortedList: List[Int]): List[Int] =
    if (sortedList.isEmpty || element <= sortedList.head) element :: sortedList
    else sortedList.head :: insertSorted(element, sortedList.tail)

  if (list.isEmpty || list.tail.isEmpty) list
  else insertSorted(list.head, insertionSort(list.tail))
}
```
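A quick sanity check can help at this point. The sketch below repeats the complete solution so that it runs on its own, then tries it on a small input; the sample list is an arbitrary choice of mine.

```scala
// Same logic as the complete solution above, repeated so the snippet is self-contained.
def insertionSort(list: List[Int]): List[Int] = {
  def insertSorted(element: Int, sortedList: List[Int]): List[Int] =
    if (sortedList.isEmpty || element <= sortedList.head) element :: sortedList
    else sortedList.head :: insertSorted(element, sortedList.tail)

  if (list.isEmpty || list.tail.isEmpty) list
  else insertSorted(list.head, insertionSort(list.tail))
}

// Arbitrary sample input, including a duplicate.
val original = List(3, 1, 4, 1, 5, 9, 2, 6)
val sorted = insertionSort(original)

println(sorted)   // List(1, 1, 2, 3, 4, 5, 6, 9)
println(original) // List(3, 1, 4, 1, 5, 9, 2, 6): the input list is not mutated
```

Because `List` is immutable, every step builds a new list, which is why the original comes out untouched.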

## 3. A Better Sort

The above solution is nice, but it has a problem: it can crash on large lists.

```
insertionSort((1 to 100000).reverse.toList) // a large list, in descending order
```

Output:

```
Exception in thread "main" java.lang.StackOverflowError
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
at blog.SortingDemo$.insertionSort(SortingDemo.scala:13)
...
```

That’s a stack overflow, caused by the large number of nested recursive calls: each one needs its own stack frame until the calls below it return. We can do better.

The solution is to use *tail calls*, or *tail recursion*, so that the stack doesn’t overflow. With tail recursion, the compiler reuses the current stack frame for the recursive call, so the recursion doesn’t occupy additional stack memory. This can only happen when recursive calls are the *last expressions on their code path*, that is, when nothing is left to compute after the call returns.
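To make the distinction concrete, here is a minimal sketch on a simpler problem, summing a list. The names `sumNaive`, `sumTailrec` and `loop` are made up for this illustration and aren’t used elsewhere in the article.

```scala
// Not tail-recursive: after the recursive call returns we still have to add list.head,
// so every call keeps its own stack frame alive.
def sumNaive(list: List[Int]): Int =
  if (list.isEmpty) 0
  else list.head + sumNaive(list.tail)

// Tail-recursive: in `loop`, the recursive call is the last thing that happens on its code path.
// The partial result travels in the extra `accumulator` argument instead of on the stack.
def sumTailrec(list: List[Int]): Int = {
  def loop(remaining: List[Int], accumulator: Int): Int =
    if (remaining.isEmpty) accumulator
    else loop(remaining.tail, accumulator + remaining.head)

  loop(list, 0)
}

println(sumNaive(List(1, 2, 3, 4)))        // 10
println(sumTailrec(List(1, 2, 3, 4)))      // 10
println(sumTailrec(List.fill(1000000)(1))) // 1000000
```

The large input at the end is only given to `sumTailrec` on purpose: `sumNaive` needs one stack frame per element and may overflow on a list of that size.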

A tail-recursive solution usually involves adding more arguments to the method. Let’s modify `insertSorted` such that it’s tail recursive:

```
def insertTailrec(element: Int, sortedList: List[Int], accumulator: List[Int]): List[Int]
```

In `accumulator` we’ll store all the numbers smaller than `element`. At the moment when `element <= sortedList.head`, all the smaller numbers of the result are in `accumulator` (in reverse order) and all the numbers greater than or equal to `element` are in `sortedList`. The implementation will work like this:

```
def insertTailrec(element: Int, sortedList: List[Int], accumulator: List[Int]): List[Int] =
  if (sortedList.isEmpty || element <= sortedList.head) accumulator.reverse ++ (element :: sortedList)
  else insertTailrec(element, sortedList.tail, sortedList.head :: accumulator)
```

This code is a bit harder to digest, and that’s normal. Let’s work through an example:

```
insertTailrec(4, [1,2,3,5], []) ---> else branch --->
insertTailrec(4, [2,3,5], [1]) ---> else branch --->
insertTailrec(4, [3,5], [2,1]) ---> else branch --->
insertTailrec(4, [5], [3,2,1]) ---> first branch --->
[3,2,1].reverse ++ (4 :: [5]) --->
[1,2,3,4,5]
```

This example also shows why we need to `.reverse` the accumulator at the end of the recursion: we prepend to it at every step, so it ends up holding the smaller numbers in descending order.

To validate whether a method is tail-recursive, we can add the `@tailrec` annotation from `scala.annotation.tailrec`. This will make the compiler check whether the recursive call indeed occurs as the last expression of its code path, and fail compilation if it doesn’t.
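As a small sketch of how that looks in practice, here is `insertTailrec` with the annotation applied. The second method, `insertNotTailrec`, is a hypothetical name I’m using for the naive insert from earlier; its annotation is left commented out because, as far as I can tell, enabling it makes the compiler reject the method.

```scala
import scala.annotation.tailrec

// Accepted: on the else branch, the recursive call is the very last operation.
@tailrec
def insertTailrec(element: Int, sortedList: List[Int], accumulator: List[Int]): List[Int] =
  if (sortedList.isEmpty || element <= sortedList.head) accumulator.reverse ++ (element :: sortedList)
  else insertTailrec(element, sortedList.tail, sortedList.head :: accumulator)

// Not in tail position: the result of the recursive call still has to be
// prepended with sortedList.head, so the annotation below stays commented out.
// @tailrec
def insertNotTailrec(element: Int, sortedList: List[Int]): List[Int] =
  if (sortedList.isEmpty || element <= sortedList.head) element :: sortedList
  else sortedList.head :: insertNotTailrec(element, sortedList.tail)

println(insertTailrec(4, List(1, 2, 3, 5), Nil)) // List(1, 2, 3, 4, 5)
println(insertNotTailrec(4, List(1, 2, 3, 5)))   // List(1, 2, 3, 4, 5)
```

Both versions return the same result; the annotation changes nothing at runtime, it only asks the compiler to verify the shape of the recursion.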

We can apply a similar technique for the “big” sort method:

```
def sortTailrec(list: List[Int], accumulator: List[Int]): List[Int] =
  if (list.isEmpty) accumulator
  else sortTailrec(list.tail, insertTailrec(list.head, accumulator, Nil))
```

In the accumulator, we store the sorted state of the elements we’ve considered so far. If the list is empty, we’ve sorted everything, so we return the accumulator. Otherwise, we take the list’s head, and we insert it into the (already sorted) accumulator via the (already tailrec) `insertTailrec` method.

Again, an example would probably illustrate this best. Assume `insertTailrec` already works correctly. Watch it carefully:

```
sortTailrec([3,1,4,2,5], []) = sortTailrec([1,4,2,5], insertTailrec(3, [], [])) =
sortTailrec([1,4,2,5], [3]) = sortTailrec([4,2,5], insertTailrec(1, [3], [])) =
sortTailrec([4,2,5], [1,3]) = sortTailrec([2,5], insertTailrec(4, [1,3], [])) =
sortTailrec([2,5], [1,3,4]) = sortTailrec([5], insertTailrec(2, [1,3,4], [])) =
sortTailrec([5], [1,2,3,4]) = sortTailrec([], insertTailrec(5, [1,2,3,4], [])) =
sortTailrec([], [1,2,3,4,5]) =
[1,2,3,4,5]
```
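If you squint at the trace above, `sortTailrec` walks the list from left to right and threads an accumulator through every step, which is the shape of a left fold. As an aside, here is a sketch of the same sort written with `foldLeft` from the standard library; `insertionSortFold` is a name I made up for this illustration.

```scala
import scala.annotation.tailrec

// Same insertTailrec as in the article, repeated so the snippet is self-contained.
@tailrec
def insertTailrec(element: Int, sortedList: List[Int], accumulator: List[Int]): List[Int] =
  if (sortedList.isEmpty || element <= sortedList.head) accumulator.reverse ++ (element :: sortedList)
  else insertTailrec(element, sortedList.tail, sortedList.head :: accumulator)

// foldLeft starts from the empty list and inserts each element into the sorted result so far,
// which is the same sequence of steps that sortTailrec performs.
def insertionSortFold(list: List[Int]): List[Int] =
  list.foldLeft(List.empty[Int])((sortedSoFar, element) => insertTailrec(element, sortedSoFar, Nil))

println(insertionSortFold(List(3, 1, 4, 2, 5))) // List(1, 2, 3, 4, 5)
```

Writing the recursion by hand is still the better exercise here, but it’s useful to recognize the pattern when you see it.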

The final code looks like this:

```
import scala.annotation.tailrec

def insertSortSmarter(list: List[Int]): List[Int] = {
  // accumulator holds the numbers smaller than element, in reverse order
  @tailrec
  def insertTailrec(element: Int, sortedList: List[Int], accumulator: List[Int]): List[Int] =
    if (sortedList.isEmpty || element <= sortedList.head) accumulator.reverse ++ (element :: sortedList)
    else insertTailrec(element, sortedList.tail, sortedList.head :: accumulator)

  // accumulator holds the sorted version of the elements consumed so far
  @tailrec
  def sortTailrec(list: List[Int], accumulator: List[Int]): List[Int] =
    if (list.isEmpty) accumulator
    else sortTailrec(list.tail, insertTailrec(list.head, accumulator, Nil))

  sortTailrec(list, Nil)
}
```

And sure enough, it works:

```
println(insertSortSmarter((1 to 100000).reverse.toList))
```

```
List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10,...)
```

We can of course generalize the method to work for any type `T` for which we have an `Ordering[T]` or some other comparison object in scope, but the goal of the article has been achieved.
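For the curious, here is one possible sketch of that generalization. It assumes an implicit `Ordering[T]` is available for the element type; the name `insertSortGeneric` is hypothetical, and the body otherwise mirrors `insertSortSmarter`.

```scala
import scala.annotation.tailrec

def insertSortGeneric[T](list: List[T])(implicit ordering: Ordering[T]): List[T] = {
  @tailrec
  def insertTailrec(element: T, sortedList: List[T], accumulator: List[T]): List[T] =
    // ordering.lteq(a, b) plays the role of a <= b for an arbitrary type T
    if (sortedList.isEmpty || ordering.lteq(element, sortedList.head)) accumulator.reverse ++ (element :: sortedList)
    else insertTailrec(element, sortedList.tail, sortedList.head :: accumulator)

  @tailrec
  def sortTailrec(remaining: List[T], accumulator: List[T]): List[T] =
    if (remaining.isEmpty) accumulator
    else sortTailrec(remaining.tail, insertTailrec(remaining.head, accumulator, Nil))

  sortTailrec(list, Nil)
}

println(insertSortGeneric(List(3, 1, 2)))                        // List(1, 2, 3)
println(insertSortGeneric(List("pear", "apple", "fig")))         // List(apple, fig, pear)
println(insertSortGeneric(List(3, 1, 2))(Ordering[Int].reverse)) // List(3, 2, 1)
```

Passing the ordering explicitly, as in the last line, is a handy way to sort in descending order without touching the algorithm.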

## 4. Conclusion

We explored how to sort lists in Scala in just about 7-8 lines of code, how the quick and dirty solution can crash with a stack overflow, and how we can approach a tail-recursive solution that avoids the stack overflow problem. You can adapt this technique to other problems as well — and in the course we squeeze the juice out of tail recursion.

In a future article, I’ll get more philosophical about how tailrec methods are equivalent to iterative algorithms, but more on that soon…