A Monads Approach for Beginners, in Scala


There are so many tutorials on Monads it’s not even funny. In fact, it’s not funny at all - the more analogies people make, the more confused readers and listeners seem to be, because how could you bridge the gap between burritos, programmable semicolons and applicative functors? Perhaps you could, IF you already knew IT.

And by IT, I mean the M word.

As much as I like to talk by way of analogies, I believe monads are not well-served by them. I scoured the web trying to find commonalities that would stick, and I couldn’t. Instead, one common question I kept seeing in comments was: OK, so how are these things useful? So in this (probably long) article I’ll approach monads through the inevitable question in a curious programmer’s mind: why are they useful to ME?

Instead of burritos and abstract math, this article is about writing code: we will show how monads solve programming problems, how they make code easier to understand and how they make us happier and more productive.

The idea of a monad is not so difficult but it’s notoriously hard to wrap your head around because it’s so general. Once it clicks, it feels simple and profound, which probably explains the itch for the gazillions of tutorials on the subject. My hope is that after this article, you too can go “hm, that wasn’t so hard”.

Part 1: the pattern

Monads are types which take values and do something interesting to them by following a certain structure. I spent 5 minutes writing that sentence. 99% of the value of monads is in their structure, which I hope will become part of your mental model. So in what follows we’ll focus on the structure.

Assume, for example, that I want to define a class which prevents multi-threaded access to a value:

case class SafeValue[+T](private val internalValue: T) {
  def get: T = synchronized {
    // note: this is just a simplified example and not 100% foolproof
    // thread safety is a very complicated topic and I don't want to digress there
    // replace the logic here with code that might be interesting for any reason
    internalValue
  }
}

Also assume some external API gave me one of those SafeValues and I want to transform it into another SafeValue:

val safeString: SafeValue[String] = gimmeSafeValue("Scala is awesome") // obtained from elsewhere
val string = safeString.get
val upperString = string.toUpperCase()
val upperSafeString = SafeValue(upperString)

Pay close attention to this pattern. Someone gives you a safe value, you unwrap it, transform it, then create another safe value.

Instead of dealing with things “sequentially” in this way, we can define the transformation in the safe value class itself:

// more elegant (and safe): use a function to transform the element, inside the wrapper
case class SafeWrapper2[+T](private val internalValue: T) {
  def transform[S](transformer: T => SafeWrapper2[S]): SafeWrapper2[S] = synchronized {
    transformer(internalValue)
  }
}

In this way, we can simply do

val safeString2 = SafeWrapper2("Scala is awesome")
val upperSafeString2 = safeString2.transform(string => SafeWrapper2(string.toUpperCase()))

This is more concise, more elegant, much safer, and more OO-encapsulated, because you never need to get the internal value of the wrapper.

This pattern (and its refactor) is very common, and is composed of two things:

1. the ability to wrap a value into my (more interesting) type - in OO terms this is just a "constructor"; we call this _unit_, or _pure_, or _apply_
2. a function that transforms a wrapper into another wrapper (perhaps of another type) in the same style as the above - we usually call this _bind_ or _flatMap_

When you’re dealing with this extract-transform-wrap pattern (I’ll call this ETW), you’ve created the conditions for a monad.
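To make the two ingredients concrete, here is a minimal sketch of the same wrapper with transform renamed to flatMap and a map added on top of it. The name SafeBox is made up for this illustration (it is not part of any library). With both methods in place, Scala lets you write for-comprehensions over it.

```scala
// Hypothetical wrapper for illustration; same idea as the safe wrapper above.
final case class SafeBox[+T](private val internalValue: T) {
  // "bind": extract the value, hand it to f, and let f do the re-wrapping
  def flatMap[S](f: T => SafeBox[S]): SafeBox[S] = synchronized {
    f(internalValue)
  }

  // map is just flatMap followed by the constructor ("unit")
  def map[S](f: T => S): SafeBox[S] = flatMap(x => SafeBox(f(x)))
}

// Because flatMap and map exist, for-comprehensions work out of the box
val shouted: SafeBox[String] = for {
  s     <- SafeBox("Scala is awesome")
  upper <- SafeBox(s.toUpperCase)
} yield upper + "!"
```

Notice that map needed nothing beyond flatMap and the constructor, which is why those two are the only essential pieces.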

Part 2: examples from the jungle

Example 1: A census application

Assume you’re populating a database with people, and you have the following code:

case class Person(firstName: String, lastName: String) {
    // you have a requirement that these fields must not be nulls
    // (require is the idiomatic precondition check; assert can be elided by compiler flags)
    require(firstName != null && lastName != null)
}

def getPerson(firstName: String, lastName: String): Person =
  if (firstName != null) {
    if (lastName != null) {
      Person(firstName.capitalize, lastName.capitalize)
    } else {
      null
    }
  } else {
    null
  }

This is a defensive Java-style implementation, which is hard to read, hard to debug and very easy to nest further. Notice the pattern: “extract” the first name, transform it, “extract” the last name, transform it, then wrap them into something interesting.

We can more elegantly solve this with Options. You’ve probably seen them before: they’re like lists, except they have at most one element. Options have both structural elements of monads: a “constructor” and a flatMap. We can therefore simplify this code:

def getPersonBetter(firstName: String, lastName: String): Option[Person] =
  Option(firstName).flatMap { fName =>
    Option(lastName).flatMap { lName =>
      Option(Person(fName.capitalize, lName.capitalize))
    }
  }

and we can then reduce this to for-comprehensions:

def getPersonFor(firstName: String, lastName: String): Option[Person] = for {
  fName <- Option(firstName)
  lName <- Option(lastName)
} yield Person(fName.capitalize, lName.capitalize)

Notice that it’s now much easier to read this code, because you no longer have to do the null-checking logic in your head; you just leave it to the Options. The null-checking part is the “interesting” thing they do.
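Here is a quick, self-contained check of that behaviour. Person and getPersonFor are repeated from above (minus the null check in the constructor) so the snippet runs on its own; the sample names are made up.

```scala
final case class Person(firstName: String, lastName: String)

def getPersonFor(firstName: String, lastName: String): Option[Person] = for {
  fName <- Option(firstName)
  lName <- Option(lastName)
} yield Person(fName.capitalize, lName.capitalize)

// Option(...) turns null into None and anything else into Some(...)
val wrappedNull: Option[String] = Option(null)

// both names present: the chain runs to the end
val ada: Option[Person] = getPersonFor("ada", "lovelace")

// a null anywhere short-circuits the whole chain to None
val missingLast: Option[Person] = getPersonFor("ada", null)
val missingFirst: Option[Person] = getPersonFor(null, "lovelace")
```

The caller gets an Option[Person] back and decides what a missing person means, instead of checking for null at every step.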

Example 2: asynchronous fetches

Imagine you’re calling some async services for your online store, like fetching from some external resource:

import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global // good enough for a demo

case class User(id: String)
case class Product(sku: String, price: Double)
// ^ never use Double for currency IN YOUR LIFE

def getUser(url: String): Future[User] = Future { ... }
def getLastOrder(userId: String): Future[Product] = Future { ... }

and assume I want to get Daniel’s last order from the store so I can send him a VAT invoice. However, all I know is the URL I need to call:

val userFuture = getUser("my.store.com/users/daniel")

Now what?

Beginners usually get stuck here: “how do I wait for this Future to end so I can get its value and move on with my life?”. So they onComplete the crap out of it:

import scala.util.Success

userFuture.onComplete {
  case Success(User(id)) =>
    val lastOrder = getLastOrder(id)
    lastOrder.onComplete {
      case Success(Product(_, p)) =>
        val vatIncludedPrice = p * 1.19
        // then pass it on
    }
}

This is really bad: it’s nested, hard to read and doesn’t cover all cases anyway (no Failure branch is handled). Notice you’re following the same ETW pattern: trying to extract an object, doing something with it and then wrapping it back. Enter flatMaps again:

val vatIncludedPrice = getUser("my.store.com/users/daniel")
    .flatMap(user => getLastOrder(user.id)) // relevant monad bit
    .map(_.price * 1.19) // then do your thing

which you can then use to pass on to your logic computing VATs. Or even better:

val vatIncludedPriceFor = for {
  user <- getUser("my.store.com/users/daniel")
  order <- getLastOrder(user.id)
} yield order.price * 1.19
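If you want to run the chain end to end, here is a minimal sketch with stubbed services. The stub bodies, the SKU and the price are invented for the demo, and Await.result is used only to get a value out in a toy script; in real code you would keep composing the Future rather than block on it.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

final case class User(id: String)
final case class Product(sku: String, price: Double)

// Stubs standing in for the real services (assumed values, for illustration only)
def getUser(url: String): Future[User] = Future(User("daniel"))
def getLastOrder(userId: String): Future[Product] = Future(Product("123-456", 100.0))

// Same chain as above: the second call only starts once the first has a value
val vatIncludedPrice: Future[Double] = for {
  user  <- getUser("my.store.com/users/daniel")
  order <- getLastOrder(user.id)
} yield order.price * 1.19

// Blocking here is for the demo only
val result: Double = Await.result(vatIncludedPrice, 2.seconds)
```

If either stub failed, the whole chain would be a failed Future, and the later steps would simply not run.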

Example 3: double-for “loops”

Lists follow the same pattern. Here’s an example: you want to return a “checkerboard” (i.e. a cartesian product) of all possible combinations of elements from two lists.

val numbers = List(1,2,3)
val chars = List('a', 'b', 'c')

Same pattern: get a value from numbers, get a value from chars, build a tuple from them, “repeat”. Solvable with flatMaps:

val checkerboard = numbers.flatMap(number => chars.map(char => (number, char)))

Or better:

val checkerboard2 = for {
  n <- numbers
  c <- chars
} yield (n, c)

which looks oddly similar to Java or imperative “loops”. The reason why this looks similar is that monads inherently describe sequential ETW computations: extract this, then transform, then wrap. If you want to process it further, same thing: extract, transform, wrap.
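To see that the two versions really are the same computation, here is a small check. The compiler rewrites the for-comprehension into the flatMap/map chain, so both values should come out identical.

```scala
val numbers = List(1, 2, 3)
val chars = List('a', 'b', 'c')

// explicit chain: flatMap for every generator except the last, map for the last
val viaFlatMap: List[(Int, Char)] =
  numbers.flatMap(n => chars.map(c => (n, c)))

// the same thing in for-comprehension syntax
val viaFor: List[(Int, Char)] = for {
  n <- numbers
  c <- chars
} yield (n, c)

val sameResult: Boolean = viaFlatMap == viaFor
```

Both lists hold the nine pairs in the same order, starting with (1, 'a') and ending with (3, 'c').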

If you got this far, congratulations! The gist is this: whenever you need to ETW, you probably need a monad, which needs 1) a constructor, 2) a flatMap in the style of the above.

Part 3: the annoying axioms

Reading the comments on various monad tutorials, I noticed that people are usually put off by the mathematical abstraction. I hope the examples above set the stage for some inherent properties of monads, which should now seem self-evident. Here goes:

Property 1: left-identity, a.k.a. the “if I had it” pattern

The ETW pattern is described in a single line: MyThing(x).flatMap(f). You can express that in two ways:

  1. If you have the x the monad was built with, you want to apply the transformation on it, so you get f(x).
  2. If you don’t, you need the ETW pattern, so you need MyThing(x).flatMap(f).

Either way, you get the same thing. For all monads, MyThing(x).flatMap(f) == f(x).
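Here is a quick sanity check of the “if I had it” property on two of the monads from the examples. The functions f and g are arbitrary picks for illustration; this is a spot check on sample values, not a proof.

```scala
val x = 3

// List: wrapping x and flatMapping f gives the same result as calling f(x)
val f: Int => List[Int] = n => List(n, n + 1)
val listHolds: Boolean = List(x).flatMap(f) == f(x)

// Option: same story
val g: Int => Option[String] = n => if (n > 0) Some(n.toString) else None
val optionHolds: Boolean = Some(x).flatMap(g) == g(x)
```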

Property 2: right-identity, a.k.a. the useless wrap

Assume you have MySuperMonad(x). What should happen when you say

MySuperMonad(x).flatMap(x => MySuperMonad(x))

If it’s not obvious yet, make it concrete:

  • example with lists: List(x).flatMap(x => List(x)) = ?
  • example with Futures: Future(x).flatMap(x => Future(x)) = ?

Nothing changes, right? An inherent property of monads is that MySuperMonad(x).flatMap(x => MySuperMonad(x)) is the same as MySuperMonad(x).
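The useless-wrap property is easy to spot-check too. A small sketch with made-up sample values; note that it holds for any value of the monad (a three-element list, an empty Option), not only for one built from a single x.

```scala
val someList = List(1, 2, 3)
val someOption: Option[String] = Some("Scala")
val emptyOption: Option[String] = None

// re-wrapping each element with the constructor changes nothing
val listUnchanged: Boolean = someList.flatMap(x => List(x)) == someList
val optionUnchanged: Boolean = someOption.flatMap(x => Option(x)) == someOption
val noneUnchanged: Boolean = emptyOption.flatMap(x => Option(x)) == emptyOption
```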

Property 3: associativity, a.k.a. ETW-ETW

This one is critical, and describes the correctness of multiple flatMap chains on monads. Say you have a list and some transformations:

val numbers = List(1,2,3)
val incrementer = (x: Int) => List(x, x + 1)
val doubler = (x: Int) => List(x, 2 * x)

If you apply both transformations in flatMaps, you get:

[1,2,3].flatMap(incrementer).flatMap(doubler) =
[1,2,2,3,3,4].flatMap(doubler) =
[1,2,2,4,2,4,3,6,3,6,4,8]

Now, if you look at that list like this,

[1,2,2,4,  2,4,3,6,  3,6,4,8]

you can see which numbers were generated by each initial element of the list. In other words, we have

[
  incrementer(1).flatMap(doubler),
  incrementer(2).flatMap(doubler),
  incrementer(3).flatMap(doubler)
]

just as if a fictitious function x => incrementer(x).flatMap(doubler) was applied to each element. In other words,

numbers.flatMap(incrementer).flatMap(doubler) == numbers.flatMap(x => incrementer(x).flatMap(doubler))

Replace numbers with any monad and put any functions inside, and you get the boo-hoo-hoo scary axiom:

MyMonad(x).flatMap(f).flatMap(g) == MyMonad(x).flatMap(x => f(x).flatMap(g))

The pattern goes like this:

  • extract
  • transform & wrap
  • extract again
  • transform & wrap

which is what the two flatMaps do; on the right-hand side we’re compressing the last 3 steps:

  • extract
  • transform & wrap & ETW
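The ETW-ETW property can be checked directly with the list and the two functions from this section. Again, this is a spot check on one input rather than a proof.

```scala
val numbers = List(1, 2, 3)
val incrementer = (x: Int) => List(x, x + 1)
val doubler = (x: Int) => List(x, 2 * x)

// two flatMaps in a row: ETW, then ETW again
val chained: List[Int] = numbers.flatMap(incrementer).flatMap(doubler)

// one flatMap, with the second ETW nested inside the function
val nested: List[Int] = numbers.flatMap(x => incrementer(x).flatMap(doubler))

val associativityHolds: Boolean = chained == nested
```

Both sides come out as the 12-element list computed by hand above. This is also what makes for-comprehensions with several generators trustworthy: how the steps are grouped does not change the result.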

Epilogue

I hope this dive into monads was as down-to-earth as possible. We use monads everywhere, often without noticing. Monads are just a mathematical label attached to something we otherwise use constantly, and the axioms look quite scary and bland. However, once you understand what monads want to solve, you naturally get a deep insight into the very nature of computation, and that alone is worth the effort of this tutorial and many before (and after) it.