
A gentle introduction to Monads

Monads?

What the hell are monads, you ask?

The formal Wikipedia definition says:

“In functional programming, a monad is an abstraction that allows structuring programs generically. Supporting languages may use monads to abstract away boilerplate code needed by the program logic. Monads achieve this by providing their own data type (a particular type for each type of monad), which represents a specific form of computation, along with one procedure to wrap values of any basic type within the monad (yielding a monadic value) and another to compose functions that output monadic values (called monadic functions).”

 

Great …

It is said that monads come with a curse. I’m not making this up: it’s called the “monad tutorial fallacy”, and legend says that once you finally understand monads, you lose the ability to explain them to others. The curse is described in four steps:

  1. person X doesn’t understand monads
  2. person X works long and hard, and groks monads
  3. person X experiences amazing feeling of enlightenment, wonders why others are not similarly enlightened
  4. person X gives horrible, incomplete, inaccurate, oversimplified, and confusing explanation of monads to others which probably makes them think that monads are stupid, dumb, worthless, overcomplicated, unnecessary, something incorrect, or a mental wank-ercise

 

Great motivation … So why write about this beast?

Monads are building blocks of the functional programming world, and functional programming has been integrated to some extent into most modern object-oriented programming languages.

Imagine you are an object-oriented programmer in a Java 6 world. Your code is a representation of logic, so how often do you find the following operations in it:

  • Operations that work on nullable/optional values
  • Operations that return a value or an error
  • Operations that work on each element of a List/Array
  • Operations that are asynchronous by nature
  • Etc.

 

You are automatically responsible for coding the following logic:

  • Execute the rest of the code only if the value isn’t null
  • Execute the rest of the code only if there was no error
  • Execute the rest of the code once for each value in the list
  • Execute the rest of the code once the asynchronous code is finished
  • Etc.

I see a repeating pattern emerging …

Wouldn’t it be nice if you could just focus on the core of your logic, without worrying about generic logic like “execute the rest of the code once for each value in the list”?
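To make the first bullet concrete, here is a minimal sketch of the Java 6 habit, written in Scala to keep one language in this post. The Customer/Address types and the customers map are hypothetical. The imperative version guards every step with its own null check; the Option version lets flatMap and map decide whether “the rest of the code” runs.

```scala
// Hypothetical domain types and data, for illustration only
case class Address(city: String)
case class Customer(name: String, address: Address)

// "jef" has no address on file, modelled Java 6 style as null
val customers: Map[String, Customer] = Map(
  "joe" -> Customer("Joe", Address("Brussels")),
  "jef" -> Customer("Jef", null)
)

// Java 6 style: every step is guarded by its own if-check
def cityImperative(id: String): String = {
  val customer = customers.getOrElse(id, null)
  if (customer != null) {
    val address = customer.address
    if (address != null) address.city else "unknown"
  } else "unknown"
}

// Option style: flatMap and map only run the rest of the code when a value is present
def cityWithOption(id: String): String =
  customers.get(id)
    .flatMap(c => Option(c.address)) // Option(null) becomes None
    .map(_.city)
    .getOrElse("unknown")

println(cityImperative("joe") + " / " + cityWithOption("joe")) // Brussels / Brussels
println(cityImperative("jef") + " / " + cityWithOption("jef")) // unknown / unknown
println(cityImperative("kim") + " / " + cityWithOption("kim")) // unknown / unknown
```

The null checks have not disappeared; they moved into Option, which is exactly the kind of generic logic the bullets above describe.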

 

Wait a second …

Java 8 was influenced by the functional programming world and introduced concepts like Streams and Optional. Those concepts seem to fit the description above …

Lists, Options and Futures are monads, and you’ve been using their Java equivalents without even knowing they were monads.
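A quick way to see the family resemblance: in Scala, List and Option expose the same flatMap method. The sketch below uses a hypothetical halve helper. Scala’s Future has a flatMap as well; it is left out here only because it needs an ExecutionContext and threads.

```scala
// Two different wrappers, one shared shape: flatMap
val pairs: List[Int] = List(1, 2, 3).flatMap(x => List(x, x * 10))

// Hypothetical helper: only even numbers can be halved
def halve(x: Int): Option[Int] = if (x % 2 == 0) Some(x / 2) else None
val halved: Option[Int] = Some(8).flatMap(halve)
val notHalved: Option[Int] = Some(7).flatMap(halve)

println(pairs)     // List(1, 10, 2, 20, 3, 30)
println(halved)    // Some(4)
println(notHalved) // None
```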

 


 

OK, so there is this concept, but why should I care about its internals? I use some implementations, I like those implementations, and I don’t have to understand them completely to take advantage of their power.

My challenge is twofold here:

  1. Break the curse and explain what the hell a monad is
  2. Give insight into why it’s a powerful concept that goes well beyond the implementations most object-oriented programmers are familiar with (Java Optional, etc.)

 

Learning about monads gave me a very different perspective on concepts that I had taken for granted for many years. That new perspective influenced my code in many ways, whether it was Java, Scala or Python.

Monads are everywhere, but that doesn’t mean we need to do anything special with them. You can use them, spot them, or ignore them completely, and you will be fine.

But when dealing with large amounts of data or with complex data modelling, I immediately noticed the advantages of taking the monad perspective into account when deciding how to write code. Monads are not the answer to every problem, but understanding them gives you a different perspective on subtle things, and you will start to think about every “if-check” in your code in new ways.

This way of thinking promotes code that is easier to test, more robust, and better suited to parallel and concurrent programming. The monad concept is entangled with other functional programming concepts like pure functions, higher-order functions and immutability, and many of those concepts contribute to code quality.

 


 

Scala?

Almost all the code and examples in this blog post are written in Scala. Why Scala, you ask?

  1. JavaScript is evolving towards Java, and Java keeps adopting Scala features. So basically everybody wants to be Scala …
  2. Scala is the preferred programming language for many data engineers.
  3. Scala can make senior developers feel very junior. Old programmers like me like to feel young again.
  4. Scala is sexy
 


 

I don’t want to hurt the feelings of Kotlin fanboys, but always remember that Kotlin is a better Java and Scala is a more powerful Java …

 

Back to monads?

Let’s get started and do a conceptual dive into these mysterious monads.

Informally, a monad is anything with a constructor and a flatMap method, and it’s a mechanism for sequencing computations …

Well, that’s it?

This is the core idea to keep in mind, and I will guide you through that statement step by step.

But before we jump to monads, we have to talk about functors first. Functors? What are those?

Monads and functors are both wrapper concepts with roots in the same branch of mathematics, category theory (https://en.wikipedia.org/wiki/Category_theory). We won’t go into the mathematical details; just think of both concepts as wrappers. A wrapper can be pictured as a box around something. The image below shows a box that can contain an Integer value.

The wrapper concept is pretty useless if we can’t perform operations on the wrapper. In the example below you can see that we can put a value in the box and apply a transformation to the value inside the box (if there is one).
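The box idea can be sketched in a few lines of Scala. Box, Filled and EmptyBox are hypothetical names, a hand-rolled stand-in for Scala’s Option: we can put a value in the box, and map applies a transformation only when there is a value inside.

```scala
// A hypothetical box that may or may not hold a value (a hand-rolled stand-in for Option)
sealed trait Box[+A] {
  // map transforms the value inside the box, if there is one
  def map[B](f: A => B): Box[B] = this match {
    case Filled(a) => Filled(f(a))
    case EmptyBox  => EmptyBox
  }
}
final case class Filled[+A](value: A) extends Box[A]
case object EmptyBox extends Box[Nothing]

val five: Box[Int] = Filled(5)   // put a value in the box
val nothing: Box[Int] = EmptyBox // a box without a value

println(five.map(_ + 1))    // Filled(6)
println(nothing.map(_ + 1)) // EmptyBox
```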

 


 

Functors and monads should both provide a way to put something inside the box. The difference between them is that you can call the map() function on a functor and the flatMap() function on a monad. The image above shows a box that may contain a value. In the Java world this is an Optional (Option in Scala).

Let that sink in …

  1. Yes, a Java Optional is a functor, because it has a map function and it’s a wrapper concept.
  2. Yes, a Java Optional is a monad, because it has a flatMap function and it’s a wrapper concept.

OK, so our definition so far is pretty simple: a functor is a wrapper with a map function, and a monad is a wrapper with a flatMap function. flatMap is a special map function (you “map” first and then you “flatten”), and map can in turn be rebuilt from flatMap plus the operation that puts a value in the box. So we can also conclude that every monad is also a functor.
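Both claims can be checked on Scala’s Option in a small sketch. Here unit is our own name for “put a value in the box” and half is a hypothetical helper; flatMap gives the same result as map followed by flatten, and map can be rebuilt from flatMap and unit.

```scala
// unit: our name for "put a plain value in the Option box"
def unit[A](a: A): Option[A] = Some(a)

// A hypothetical function that returns a boxed value
def half(x: Int): Option[Int] = if (x % 2 == 0) Some(x / 2) else None

val box: Option[Int] = Some(8)

// flatMap behaves like map followed by flatten
val viaFlatMap: Option[Int] = box.flatMap(half)
val viaMapThenFlatten: Option[Int] = box.map(half).flatten // map alone gives Option[Option[Int]]

// map can be rebuilt from flatMap and unit, which is why every monad is also a functor
val viaMap: Option[Int] = box.map(_ + 1)
val mapFromFlatMap: Option[Int] = box.flatMap(x => unit(x + 1))

println(viaFlatMap)        // Some(4)
println(viaMapThenFlatten) // Some(4)
println(viaMap)            // Some(9)
println(mapFromFlatMap)    // Some(9)
```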

Are we there yet? No, keep hanging in there …

Because monads have mathematical roots, we need to understand the laws that apply to them. Let’s formalise the box / map / flatten principles:

    				
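    • We need a Unit functionality that allows us to put a type A inside a box F.
        def unit: A → F[A]
    • We need a Map functionality that allows us to apply a function A → B to a type A inside a box F.
        def map: F[A] → (A → B) → F[B]
    • We need a Flatten functionality that allows us to unpack a type A inside 2 boxes F.
        def flatten: F[F[A]] → F[A]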

    With these three principles in mind, we can start writing a formal Scala definition:

      1. We need a unit function that takes a parameter of type A and puts it inside the wrapper M.
      2. Our M type has a trait (an interface in Java) that says we need a flatMap function on our wrapper type. That flatMap function takes a function as a parameter, and that function maps a type A to a type B wrapped inside our wrapper/monad type (A → M[B]).
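The two ingredients described above can be written down as a minimal Scala sketch. The names M and unit follow the text; Id is a hypothetical, trivial wrapper (it always holds exactly one value) that exists only so the sketch can run.

```scala
// The trait from step 2: any wrapper M must offer flatMap
trait M[A] {
  // flatMap takes a function from A to a wrapped B and returns a wrapped B
  def flatMap[B](f: A => M[B]): M[B]
}

// A hypothetical concrete wrapper that always holds exactly one value
final case class Id[A](value: A) extends M[A] {
  def flatMap[B](f: A => M[B]): M[B] = f(value)
}

// Step 1: unit puts a plain value of type A inside the wrapper
def unit[A](x: A): M[A] = Id(x)

val result: M[Int] = unit(20).flatMap(x => unit(x + 1))
println(result) // Id(21)
```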
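    Every monad also has to obey three laws:

    left-identity law: wrapping a value with unit and then calling flatMap(f) is the same as calling f directly.

        unit(x).flatMap(f) == f(x)

    right-identity law: calling flatMap with unit leaves a monad unchanged.

        m.flatMap(unit) == m

    associativity law: it doesn’t matter how nested flatMap calls are grouped.

        m.flatMap(f).flatMap(g) == m.flatMap(x ⇒ f(x).flatMap(g))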
    				
        println("Let's dive into the world of monad laws ...")

        // The Option monad: unit wraps a plain value in Some
        def unit(x: Int): Option[Int] = Some(x)
        def plus5(x: Int): Option[Int] = Some(x + 5)
        def plus6(x: Int): Option[Int] = Some(x + 6)

        /*
        left-identity law:
        unit(x).flatMap(f) == f(x)
        */
        println("left identity law -> both statements below are equal")
        println(unit(5).flatMap(plus5))
        println(plus5(5))

        /*
        right-identity law:
        m.flatMap(unit) == m
        */
        println("right identity law -> both statements below are equal")
        val optionMonad: Option[Int] = unit(5)
        println(optionMonad.flatMap(unit))
        println(optionMonad)

        /*
        associativity law:
        m.flatMap(f).flatMap(g) == m.flatMap(x ⇒ f(x).flatMap(g))
        */
        println("associativity law -> both statements below are equal")
        // with f = plus5 and g = plus6, the right-hand side nests the second flatMap inside one function
        val plus5ThenPlus6: Int => Option[Int] = x => plus5(x).flatMap(plus6)
        println(optionMonad.flatMap(plus5).flatMap(plus6))
        println(optionMonad.flatMap(plus5ThenPlus6))

    This gives us the following output:

    Let's dive into the world of monad laws ...
    left identity law -> both statements below are equal
    Some(10)
    Some(10)
    right identity law -> both statements below are equal
    Some(5)
    Some(5)
    associativity law -> both statements below are equal
    Some(16)
    Some(16)
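Printing one example per law is a good start; checking the laws over a range of inputs is a bit more convincing. The sketch below does that for Scala’s List monad. The functions unit, f and g are hypothetical choices picked only to exercise the laws (g deliberately returns an empty list for odd numbers).

```scala
// Hypothetical functions into the List monad, used only to exercise the laws
def unit(x: Int): List[Int] = List(x)
def f(x: Int): List[Int] = List(x, x + 1)
def g(x: Int): List[Int] = if (x % 2 == 0) List(x * 10) else Nil

val inputs = (-3 to 3).toList

// left identity: unit(x).flatMap(f) == f(x)
val leftIdentity = inputs.forall(x => unit(x).flatMap(f) == f(x))

// right identity: m.flatMap(unit) == m, using f(x) as the monad m
val rightIdentity = inputs.forall(x => f(x).flatMap(unit) == f(x))

// associativity: m.flatMap(f).flatMap(g) == m.flatMap(y => f(y).flatMap(g))
val associativity = inputs.forall { x =>
  val m = f(x)
  m.flatMap(f).flatMap(g) == m.flatMap(y => f(y).flatMap(g))
}

println(s"left: $leftIdentity, right: $rightIdentity, assoc: $associativity")
// left: true, right: true, assoc: true
```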

    I hear you thinking: this is all great, but what is all the fuss about …

    In the functional programming world, it’s all about functions. Writing the functions is usually not the hard part; gluing them together is. Remember our first statement: “a monad is anything with a constructor and a flatMap method and it’s a mechanism for sequencing computations”.

    So let’s first do some sequencing. We will create a user service that retrieves a user object. That user object may have a child (recursive data), and I want to see whether our user has a grandchild …

    				
        println("Let's do some sequencing ...")

        // I want to get a user and see if they have a grandchild
        case class User(
          child: Option[User],
          name: String
        )

        object UserService {
          // stub: ignores the name and always returns Joe, whose child Jef has a child Jan
          def loadUser(name: String): Option[User] =
            Some(User(child = Some(User(child = Some(User(child = None, name = "Jan")), name = "Jef")), name = "Joe"))
        }

        // sequencing via flatMap: each step returns an Option, and flatMap keeps the result flat
        val grandChild: Option[User] = UserService.loadUser("Kristof")
          .flatMap(u => u.child)
          .flatMap(u => u.child)

        println(grandChild.map(_.name).getOrElse("No grandchild found ..."))

        // sequencing via map: the first map already yields an Option[Option[User]],
        // so the second step has to unwrap it with unsafe get calls
        val grandChildMap: Option[User] = UserService.loadUser("Kristof")
          .map(u => u.child)
          .map(u => u.get.child.get)

        println(grandChildMap.map(_.name).getOrElse("No grandchild found ..."))

    This gives us the following output:
    Let's do some sequencing ...
    Jan
    Jan

    This is great, as the flatMap operator allows us to chain operations together and keep doing that for as long as we want. This shows the advantage of flatMap over map:

    1. With the map operator we can do one map operation, and then we run into trouble: we end up with nested Options that we have to unwrap with unsafe get calls …
    2. With the flatMap operator, the chain stays very clean.
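The “trouble” with map only shows up when a link in the chain is missing. Here is a small sketch with a hypothetical user who has no children: the flatMap chain quietly ends in None, while the map version throws a NoSuchElementException on get (wrapped in scala.util.Try here so the failure can be observed).

```scala
import scala.util.Try

case class User(child: Option[User], name: String)

// A hypothetical user without any children
val lonelyUser: Option[User] = Some(User(child = None, name = "Joe"))

// flatMap: the missing child short-circuits the chain to None
val viaFlatMap: Option[User] = lonelyUser
  .flatMap(u => u.child)
  .flatMap(u => u.child)
println(viaFlatMap.map(_.name).getOrElse("No grandchild found ...")) // No grandchild found ...

// map: we get Some(None), and unwrapping it with get throws NoSuchElementException
val nested: Option[Option[User]] = lonelyUser.map(u => u.child)
val viaMap: Try[Option[User]] = Try(nested.map(u => u.get.child.get))
println(viaMap.isFailure) // true
```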

     

    This principle generalises very well. The example above is very simple, but how does it help us chain different functions? Remember that we are in a functional programming world, where we like to use pure functions that have no side effects. Below you can see two pure functions.

    				
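        // returns None instead of throwing when str is not a number
        def parseInt(str: String): Option[Int] = scala.util.Try(str.toInt).toOption
        // returns None instead of throwing on division by zero
        def divide(a: Int, b: Int): Option[Int] = if (b == 0) None else Some(a / b)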

    I want to combine them into useful logic, and we are in luck: they return Option, which happens to be a monad … This means we have a flatMap function, and that means we can build sequences with these functions.

    				
        def parseInt(str: String): Option[Int] = scala.util.Try(str.toInt).toOption
        def divide(a: Int, b: Int): Option[Int] = if (b == 0) None else Some(a / b)

        println("the monad way ...")

        // each flatMap only continues when the previous step produced a Some
        def stringDivideBy(aStr: String, bStr: String): Option[Int] =
          parseInt(aStr).flatMap { aNum =>
            parseInt(bStr).flatMap { bNum =>
              divide(aNum, bNum)
            }
          }

        println(stringDivideBy("8", "4"))
        println(stringDivideBy("a", "4"))
        println(stringDivideBy("8", "0"))

    This gives us the following output:

    the monad way ...
    Some(2)
    None
    None

    The semantics of the stringDivideBy method:

    • the first call to parseInt returns a None or a Some;
    • if it returns a Some, the flatMap method calls our function and passes us the integer aNum;
    • the second call to parseInt returns a None or a Some;
    • if it returns a Some, the flatMap method calls our function and passes us bNum;
    • the call to divide returns a None or a Some, which is our result.

     

    At each step, flatMap chooses whether to call our function, and our function generates the next computation in the sequence.
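The same sequencing works for other monads. For the “value or an error” case from the start of this post, Scala’s Either (right-biased since Scala 2.12, so flatMap only continues on a Right) can carry an error message instead of a bare None. A sketch, where parseIntE, divideE and stringDivideByE are hypothetical Either-based variants of the functions above:

```scala
import scala.util.Try

// Hypothetical Either-based variants: Left carries an error message
def parseIntE(str: String): Either[String, Int] =
  Try(str.toInt).toOption.toRight(s"'$str' is not a number")

def divideE(a: Int, b: Int): Either[String, Int] =
  if (b == 0) Left("division by zero") else Right(a / b)

// Same shape as stringDivideBy: each flatMap only continues on a Right
def stringDivideByE(aStr: String, bStr: String): Either[String, Int] =
  parseIntE(aStr).flatMap { aNum =>
    parseIntE(bStr).flatMap { bNum =>
      divideE(aNum, bNum)
    }
  }

println(stringDivideByE("8", "4")) // Right(2)
println(stringDivideByE("a", "4")) // Left('a' is not a number)
println(stringDivideByE("8", "0")) // Left(division by zero)
```

Only the wrapper changed; the chaining code has exactly the same structure as the Option version.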

    • A flatMap function
    • An optional map function

     

    1. We create our flatMap function; remember that we will use it to sequence a chain of operations. It’s similar to our map implementation, but instead of just keeping the current message, it appends the message produced by the next step.

        case class Debuggable(value: Int, msg: String) {

          // map transforms the value and keeps the message as it is
          def map(f: Int => Int): Debuggable = {
            val nextValue = f(value)
            Debuggable(nextValue, msg)
          }

          // flatMap runs the next step and appends its message to ours
          def flatMap(f: Int => Debuggable): Debuggable = {
            val nextValue = f(value)
            Debuggable(nextValue.value, msg + "\n" + nextValue.msg)
          }
        }

    We seem to have a custom monad. Now let’s create some functions that use our monad (they are all pure functions that return our monad).

    				
        def pureFunctionPlus1(a: Int): Debuggable = {
          val result = a + 1
          Debuggable(result, s"f: input: $a, result: $result")
        }

        def pureFunctionPlus2(a: Int): Debuggable = {
          val result = a + 2
          Debuggable(result, s"g: input: $a, result: $result")
        }

        def pureFunctionPlus3(a: Int): Debuggable = {
          val result = a + 3
          Debuggable(result, s"h: input: $a, result: $result")
        }

        val finalResult = pureFunctionPlus1(100).flatMap { fResult =>
          pureFunctionPlus2(fResult).flatMap { gResult =>
            pureFunctionPlus3(gResult).map { hResult =>
              hResult
            }
          }
        }

    This gives us the following results:

    • final result :
    106
    • log message :
    f: input: 100, result: 101
    g: input: 101, result: 103
    h: input: 103, result: 106

    This is a big step, so let it sink in … You can create your own monads for your specific needs. The functions above didn’t care how they were combined or in what position they were executed; they only needed to return a Debuggable object.

    Scala provides syntax, the for comprehension, that hides the ugly flatMap chaining so you can write the same code more concisely:

    				
        val finalResultScala = for {
          fRes <- pureFunctionPlus1(100)
          gRes <- pureFunctionPlus2(fRes)
          hRes <- pureFunctionPlus3(gRes)
        } yield hRes

        println(s"final Scala for comprehension value: ${finalResultScala.value}")
        println(s"final Scala for comprehension msg: ${finalResultScala.msg}")

    This gives us the following output:

    final Scala for comprehension value: 106
    final Scala for comprehension msg: f: input: 100, result: 101
    g: input: 101, result: 103
    h: input: 103, result: 106

    What you just witnessed is known in the monad world as a writer monad (https://en.wikipedia.org/wiki/Monad_(functional_programming)#Writer_monad). The code is not perfect and not very reusable (your pure functions have to return a Debuggable object), but it shows that you can write custom monads for your own needs.
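One way to make the idea more reusable, sketched under the assumption that a Vector[String] is a good enough log: make the value type generic. Writer and plus below are hypothetical names, not a library API. Because map and flatMap are defined, the for comprehension works just as it did for Debuggable.

```scala
// A hypothetical generic writer: a value of any type A plus an accumulated log
final case class Writer[A](value: A, log: Vector[String]) {
  // map transforms the value and leaves the log untouched
  def map[B](f: A => B): Writer[B] = Writer(f(value), log)

  // flatMap runs the next step and concatenates its log entries after ours
  def flatMap[B](f: A => Writer[B]): Writer[B] = {
    val next = f(value)
    Writer(next.value, log ++ next.log)
  }
}

object Writer {
  // unit wraps a plain value with an empty log
  def unit[A](a: A): Writer[A] = Writer(a, Vector.empty)
}

// A hypothetical step function, reusable for any increment n
def plus(n: Int)(a: Int): Writer[Int] =
  Writer(a + n, Vector(s"plus$n: input: $a, result: ${a + n}"))

val run: Writer[String] = for {
  a <- plus(1)(100)
  b <- plus(2)(a)
  c <- plus(3)(b)
} yield s"total: $c"

println(run.value)              // total: 106
println(run.log.mkString("\n"))
```

A side effect of this design: the empty Vector is a neutral starting log, so unit(x).flatMap(f) == f(x) holds exactly. With Debuggable’s msg + "\n" + … concatenation, no starting message is neutral, so the left-identity law can’t hold there.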

    The imperative solution to this problem is way simpler. Why prefer the monad?

    Let’s say we start profiling code that uses multiple threads. Standard imperative logging techniques can result in interleaved messages from multiple threads. The writer monad doesn’t have this problem, because each computation carries its own log, and using the monad version will actually simplify your solution. Remember that functional programming shines in the parallel / concurrent programming world.

    This is just one very specific use case; there are many problem patterns that can be solved with monads. Think of them as the functional programming equivalent of the Gang of Four patterns (https://www.gofpatterns.com/). Absorb these best practices and decide whether a monad implementation is the best fit for your use case.

    A good monad can take in clean, pure functions that are robust, easy to test, easy to read, etc. Those clean functions can be used in a monad implementation or in a non-monad implementation; the monad should take the boilerplate context off your hands. You can write code, use it in the Future monad context, reuse it in another monad context, or even use it outside any monad context.
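As a small illustration of that reuse, here is a hypothetical pure function (applyDiscount is an invented name) used unchanged in three monadic contexts and once outside any monad. It goes through map rather than flatMap because it returns a plain Int. Future is left out only to keep the sketch free of threads and an ExecutionContext.

```scala
import scala.util.Try

// A hypothetical pure function: it knows nothing about Option, List or Try
def applyDiscount(price: Int): Int = price * 90 / 100

// The same function reused in different monadic contexts
val inOption: Option[Int] = Some(200).map(applyDiscount)
val inList: List[Int] = List(100, 200).map(applyDiscount)
val inTry: Try[Int] = Try("300".toInt).map(applyDiscount)

// ...and in no monadic context at all
val plain: Int = applyDiscount(400)

println(inOption) // Some(180)
println(inList)   // List(90, 180)
println(inTry)    // Success(270)
println(plain)    // 360
```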

    I hope you now have a feeling of what Monads are and why they matter in certain software development contexts.

     


    Kristof Slechten

    Software Crafter

    Kristof Slechten holds a master’s degree in Computer Science from the VUB and specialises in projects involving big data & machine learning. Kristof currently works at Imes Dexis, where he develops research tracks around machine learning. He also contributes to various internal research projects on AI.