EDSLs as functions

by Adelbert Chang on Oct 26, 2016

This is the second of a series of articles on “Monadic EDSLs in Scala.”

Perhaps the most direct way to start writing an EDSL is to start writing functions. Let’s say we want a language for talking about sets of integers.

trait SetLang {
  def add(i: Int, set: Set[Int]): Set[Int]
  def remove(i: Int, set: Set[Int]): Set[Int]
  def exists(i: Int, set: Set[Int]): Boolean
}

This works, but only as long as we are content to work with scala.collection.Set. As it stands we cannot talk about other kinds of sets, such as Bloom filters or sets controlled by other threads. Our language isn’t abstract enough, so let’s remove all traces of Set.

trait SetLang[F[_]] {
  def add(i: Int, set: F[Int]): F[Int]
  def remove(i: Int, set: F[Int]): F[Int]
  def exists(i: Int, set: F[Int]): Boolean

  // Given unknown F we no longer know how to create an empty set
  // so we add the capability to our language
  def empty: F[Int]
}

We’ve parameterized our language with a higher-kinded type which represents the context of our set. A similar parameterization could be done with a *-kinded type (e.g. SetLang[A]) but since this series focuses on monadic EDSLs, the choice is made for us.

Now we can write mini-programs which talk about some abstract set yet to be determined.

def program[F[_]](lang: SetLang[F]): Boolean = {
  import lang._
  exists(10, remove(5, add(10, add(5, empty))))
}

Interpretation of our program is done by implementing SetLang and passing an instance into program.
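As a minimal sketch of what such an interpreter might look like, the following implements SetLang for the standard library's immutable Set. The names SetLangExample and StdSet are made up for illustration, and the trait and program are repeated only so that the snippet stands alone.

```scala
object SetLangExample {
  // Repeated from above so that the snippet is self-contained
  trait SetLang[F[_]] {
    def add(i: Int, set: F[Int]): F[Int]
    def remove(i: Int, set: F[Int]): F[Int]
    def exists(i: Int, set: F[Int]): Boolean
    def empty: F[Int]
  }

  def program[F[_]](lang: SetLang[F]): Boolean = {
    import lang._
    exists(10, remove(5, add(10, add(5, empty))))
  }

  // Hypothetical interpreter backed by the immutable standard library Set
  object StdSet extends SetLang[Set] {
    def add(i: Int, set: Set[Int]): Set[Int] = set + i
    def remove(i: Int, set: Set[Int]): Set[Int] = set - i
    def exists(i: Int, set: Set[Int]): Boolean = set.contains(i)
    def empty: Set[Int] = Set.empty
  }

  // 5 and 10 are added, 5 is removed, so 10 is still a member
  val result: Boolean = program[Set](StdSet)
}
```

Swapping in a different interpreter, say one backed by a Bloom filter, would leave program untouched.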

However, our language is still not abstract enough. Replacing Set with F allows us to swap in implementations of sets, but doesn’t allow us to talk about the context. Consider the behavior of exists if F represents some remote set. Since exists returns a Boolean, checking membership must be a synchronous operation despite the set living on another node.

It’s also tedious to thread the set through each method manually.

We can solve both problems by generalizing the use of F to some context that is able to read and write to some set (think Set[Int] => (Set[Int], A)).

trait SetLang[F[_]] {
  def add(i: Int): F[Unit]
  def remove(i: Int): F[Unit]
  def exists(i: Int): F[Boolean]

  // No longer need `empty` since the "context" has it already
}

SetLang can now talk about the effects around interpretation, such as asynchrony.

import scala.concurrent.Future

type AsyncSet[A] = Set[Int] => Future[(Set[Int], A)]

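object AsyncSet extends SetLang[AsyncSet] {
  // Each operation completes immediately in this in-memory version;
  // a remote set would complete the Future once the other node responds
  def add(i: Int): Set[Int] => Future[(Set[Int], Unit)] =
    set => Future.successful((set + i, ()))

  def remove(i: Int): Set[Int] => Future[(Set[Int], Unit)] =
    set => Future.successful((set - i, ()))

  def exists(i: Int): Set[Int] => Future[(Set[Int], Boolean)] =
    set => Future.successful((set, set.contains(i)))
}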

This new encoding introduces a new but important problem: how do we combine the results of multiple calls to SetLang methods? In the previous encoding we could add and remove by threading the set from one call to the next. With this change to represent a context, it’s not clear how to do that.
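To make the problem concrete, here is a sketch using the synchronous shape Set[Int] => (Set[Int], A) mentioned earlier. The names ManualThreading, SetOp, and andThen are hypothetical. Sequencing two operations by hand means running the first, unpacking its result, and feeding the updated set to the second.

```scala
object ManualThreading {
  // The synchronous "context": read a set, produce an updated set and a value
  type SetOp[A] = Set[Int] => (Set[Int], A)

  def add(i: Int): SetOp[Unit] = set => (set + i, ())
  def remove(i: Int): SetOp[Unit] = set => (set - i, ())
  def exists(i: Int): SetOp[Boolean] = set => (set, set.contains(i))

  // Hand-written sequencing: run `first`, then run `next` on the updated set
  def andThen[A, B](first: SetOp[A])(next: A => SetOp[B]): SetOp[B] =
    set => {
      val (updated, a) = first(set)
      next(a)(updated)
    }

  val program: SetOp[Boolean] =
    andThen(add(5))(_ => andThen(add(10))(_ => andThen(remove(5))(_ => exists(10))))

  val result: (Set[Int], Boolean) = program(Set.empty)
}
```

This andThen has the shape of a monadic flatMap, but it is tied to one particular context. Every other context, such as the Future-based one above, would need its own hand-written version.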

Fortunately we are now in a position to leverage a powerful tool: monads. By extending our set language to be monadic we recover composition in an elegant way. The Cats library is used for demonstration purposes, but the discussion applies equally to Scalaz.

import cats.Monad
import cats.implicits._

trait SetLang[F[_]] {
  // See: https://typelevel.org/blog/2016/09/30/subtype-typeclasses.html
  // for why the `Monad` instance is defined as a member as opposed to inherited
  def monad: Monad[F]

  def add(i: Int): F[Unit]
  def remove(i: Int): F[Unit]
  def exists(i: Int): F[Boolean]
}

def program[F[_]](lang: SetLang[F]): F[Boolean] = {
  import lang._
  implicit val monadInstance: Monad[F] = monad
  for {
    _ <- add(5)
    _ <- add(10)
    _ <- remove(5)
    b <- exists(10)
  } yield b
}

Defining an interpreter starts by identifying a target context. Since the context computes values while updating state, this suggests the state monad.

import cats.data.State

object ScalaSet extends SetLang[State[Set[Int], ?]] {
  val monad = Monad[State[Set[Int], ?]]

  def add(i: Int): State[Set[Int], Unit] =
    State.modify(_ + i)

  def remove(i: Int): State[Set[Int], Unit] =
    State.modify(_ - i)

  def exists(i: Int): State[Set[Int], Boolean] =
    State.inspect(_(i))
}
val state = program[State[Set[Int], ?]](ScalaSet)
// state: cats.data.StateT[cats.Eval,scala.collection.immutable.Set[Int],Boolean] = cats.data.StateT@ce9f626

state.run(Set.empty).value
// res5: (scala.collection.immutable.Set[Int], Boolean) = (Set(10),true)

Note that calling program did not require any context-specific knowledge, so we could define another interpreter, perhaps one that talks to a set concurrently.

import cats.data.StateT
import scala.concurrent.{ExecutionContext, Future}

// Asynchronous state
def AsyncSet(implicit ec: ExecutionContext): SetLang[StateT[Future, Set[Int], ?]] =
  new SetLang[StateT[Future, Set[Int], ?]] {
    val monad = Monad[StateT[Future, Set[Int], ?]]

    def add(i: Int): StateT[Future, Set[Int], Unit] =
      StateT.modify(_ + i)

    def remove(i: Int): StateT[Future, Set[Int], Unit] =
      StateT.modify(_ - i)

    def exists(i: Int): StateT[Future, Set[Int], Boolean] =
      StateT.inspect(_(i))
  }
// No changes to `program` required
val result = program(AsyncSet(ExecutionContext.global))
// result: cats.data.StateT[scala.concurrent.Future,scala.collection.immutable.Set[Int],Boolean] = cats.data.StateT@1c029382

SetLang captures the structure of a computation, but leaves open its interpretation.

Monad transformers and classes

As it turns out, SetLang is an example of an encoding often referred to as MTL-style.

Monads in monads

One of the motivations for monad classes is to remove the need to specify monad transformer stacks. The following example is adapted from “Functional Programming with Overloading and Higher-Order Polymorphism” by Professor Mark P. Jones.

Consider a program that is open to failure and computes with some state. This suggests a combination of Either and State, both of which have monad transformers. All that is left is to decide which transformer to use.

type App1[A] = EitherT[State[S, ?], Error, A]
            // State[S, Either[Error, A]]
            // S => (S, Either[Error, A])

type App2[A] = StateT[Either[Error, ?], S, A]
            // S => Either[Error, (S, A)]

While App1 and App2 are both valid compositions, their semantics differ. App1 describes a program where the computation of a value at each transition may fail but any changes to the state are preserved, whereas App2 describes a program where the entire transition may fail, taking the state with it.
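The difference can be seen without any library by working with the unfolded function shapes shown in the comments above. The following is only a sketch: StackSemantics, flatMap1, flatMap2, and the helper values are hypothetical names, with Error fixed to String and S fixed to Int.

```scala
object StackSemantics {
  type Error = String
  type S = Int

  // Unfolded shapes of the two stacks
  type App1[A] = S => (S, Either[Error, A])
  type App2[A] = S => Either[Error, (S, A)]

  // Sequencing for App1: on error, stop but keep the state reached so far
  def flatMap1[A, B](fa: App1[A])(f: A => App1[B]): App1[B] =
    s => fa(s) match {
      case (s1, Right(a)) => f(a)(s1)
      case (s1, Left(e))  => (s1, Left(e))
    }

  // Sequencing for App2: on error, the state is lost along with the value
  def flatMap2[A, B](fa: App2[A])(f: A => App2[B]): App2[B] =
    s => fa(s) match {
      case Right((s1, a)) => f(a)(s1)
      case Left(e)        => Left(e)
    }

  val increment1: App1[Unit] = s => (s + 1, Right(()))
  val increment2: App2[Unit] = s => Right((s + 1, ()))
  def fail1[A](e: Error): App1[A] = s => (s, Left(e))
  def fail2[A](e: Error): App2[A] = _ => Left(e)

  // The same program for both shapes: increment the state, then fail
  val program1: App1[Int] = flatMap1(increment1)(_ => fail1[Int]("fail"))
  val program2: App2[Int] = flatMap2(increment2)(_ => fail2[Int]("fail"))

  val result1: (S, Either[Error, Int]) = program1(0) // (1, Left("fail"))
  val result2: Either[Error, (S, Int)] = program2(0) // Left("fail")
}
```

Starting from state 0, the first shape still reports the incremented state next to the error, while the second reports only the error.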

We can abstract away the difference by creating type classes which provide the relevant operations we need.

trait MonadError[F[_], E] {
  def monad: Monad[F]

  def raiseError[A](e: E): F[A]
  def handleErrorWith[A](fa: F[A])(f: E => F[A]): F[A]
}

trait MonadState[F[_], S] {
  def monad: Monad[F]

  def get: F[S]
  def set(s: S): F[Unit]
}

Similar type classes exist for the Reader and Writer data types. These type classes are provided in both Cats and Scalaz, with some caveats.

With these type classes in place we can write functions against them as opposed to specific transformer stacks. Furthermore, our functions can specify exactly which operations they need, which helps correctness and parametricity.
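As a rough illustration of asking for exactly the operations needed, the sketch below declares a function that requires only state access. It deliberately avoids Cats: Monad and MonadState are minimal local stand-ins with the same shape as the traits above, and MinimalConstraints, increment, IntState, and intStateInstance are hypothetical names.

```scala
object MinimalConstraints {
  // Minimal local stand-ins, not the Cats type classes
  trait Monad[F[_]] {
    def pure[A](a: A): F[A]
    def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  }

  trait MonadState[F[_], S] {
    def monad: Monad[F]
    def get: F[S]
    def set(s: S): F[Unit]
  }

  // Only state is requested, so this function has no way to raise an error
  def increment[F[_]](implicit F: MonadState[F, Int]): F[Unit] =
    F.monad.flatMap(F.get)(i => F.set(i + 1))

  // A hypothetical instance for plain state-passing functions
  type IntState[A] = Int => (Int, A)

  implicit val intStateInstance: MonadState[IntState, Int] =
    new MonadState[IntState, Int] {
      val monad: Monad[IntState] = new Monad[IntState] {
        def pure[A](a: A): IntState[A] = s => (s, a)
        def flatMap[A, B](fa: IntState[A])(f: A => IntState[B]): IntState[B] =
          s => {
            val (s1, a) = fa(s)
            f(a)(s1)
          }
      }
      def get: IntState[Int] = s => (s, s)
      def set(s: Int): IntState[Unit] = _ => (s, ())
    }

  val incremented: IntState[Unit] = increment[IntState]
  val result: (Int, Unit) = incremented(41)
}
```

Because increment mentions only MonadState, a reader can tell from the signature alone that it reads and writes an Int and does nothing else with the context.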

import cats.{MonadError, MonadState}
import cats.data.{EitherT, State, StateT}

def program[F[_]](implicit F0: MonadError[F, String],
                           F1: MonadState[F, Int]): F[Int] =
  F0.flatMap(F1.get) { i =>
    F0.raiseError[Int]("fail")
  }

Our program can then be instantiated with either transformer stack.

import cats.implicits._

// At the time of this writing Cats does not have these instances
// so they are defined here.
//
// Additionally, both Cats and Scalaz 7 have encoding issues
// with these MTL type classes which require us to redefine Monad when
// defining MonadState instances, despite there already being one.
implicit def eitherTMonadState[F[_], E, S](implicit F: MonadState[F, S]): MonadState[EitherT[F, E, ?], S] =
  new MonadState[EitherT[F, E, ?], S] {
    def get: EitherT[F, E, S] =
      EitherT(F.get.map(Right(_)))

    def set(s: S): EitherT[F, E, Unit] =
      EitherT(F.set(s).map(Right(_)))

    def flatMap[A, B](fa: EitherT[F, E, A])
                     (f: A => EitherT[F, E, B]): EitherT[F, E, B] =
      fa.flatMap(f)

    def pure[A](x: A): EitherT[F, E, A] =
      EitherT.pure(x)

    def tailRecM[A, B](a: A)(f: A => EitherT[F, E, Either[A, B]]): EitherT[F, E, B] =
      EitherT.catsDataMonadErrorForEitherT[F, E].tailRecM(a)(f)
  }

implicit def stateTMonadError[F[_], E, S](implicit F: MonadError[F, E]): MonadError[StateT[F, S, ?], E] =
  new MonadError[StateT[F, S, ?], E] {
    def handleErrorWith[A](fa: StateT[F, S, A])(f: E => StateT[F, S, A]): StateT[F, S, A] =
      StateT[F, S, A] { (s: S) =>
        val state: F[(S, A)] = fa.run(s)
        F.handleErrorWith(state)(e => f(e).run(s))
      }

    def raiseError[A](e: E): StateT[F, S, A] =
      StateT.lift(F.raiseError(e))

    def flatMap[A, B](fa: StateT[F, S, A])(f: A => StateT[F, S, B]): StateT[F, S, B] =
      fa.flatMap(f)

    def pure[A](x: A): StateT[F, S, A] = StateT.pure(x)

    def tailRecM[A, B](a: A)(f: A => StateT[F, S, Either[A, B]]): StateT[F, S, B] =
      StateT.catsDataMonadStateForStateT[F, S].tailRecM(a)(f)
  }

type App1[A] = EitherT[State[Int, ?], String, A]

type App2[A] = StateT[Either[String, ?], Int, A]
val app1 = program[App1]
// app1: App1[Int] = EitherT(cats.data.StateT@5fdc056d)

val app2 = program[App2]
// app2: App2[Int] = cats.data.StateT@72493a33

Composing languages

From one angle we can view our set language, or more generally any EDSL in MTL-style, as an effect like MonadError and MonadState. From another angle we can view MonadError and MonadState as EDSLs that talk about errors and stateful computations. We can eliminate the distinctions by renaming SetLang to MonadSet and treating it as a type class.

import cats.Monad
import cats.implicits._

trait MonadSet[F[_]] {
  def monad: Monad[F]

  def add(i: Int): F[Unit]
  def remove(i: Int): F[Unit]
  def exists(i: Int): F[Boolean]
}

Composing multiple languages then becomes adding constraints to functions, and interpretation becomes instantiating type parameters that satisfy the constraints.

trait MonadCalc[F[_]] {
  def monad: Monad[F]

  def lit(i: Int): F[Int]
  def plus(l: F[Int], r: F[Int]): F[Int]
}

def setProgram[F[_]: MonadSet](i: Int): F[Boolean] =
  implicitly[MonadSet[F]].exists(i)

def calcProgram[F[_]: MonadCalc]: F[Int] = {
  val calc = implicitly[MonadCalc[F]]
  calc.plus(calc.lit(1), calc.lit(2))
}

def composedProgram[F[_]: MonadCalc: MonadSet]: F[Boolean] = {
  implicit val monad: Monad[F] = implicitly[MonadCalc[F]].monad
  for {
    i <- calcProgram[F]
    b <- setProgram(i)
  } yield b
}

// Instance

// Instances are defined together but nothing is stopping us from defining
// these separately, perhaps one in the MonadSet object and another in the
// SetState object.
implicit val stateInstance: MonadSet[State[Set[Int], ?]] with MonadCalc[State[Set[Int], ?]] =
  new MonadSet[State[Set[Int], ?]] with MonadCalc[State[Set[Int], ?]] {
    val monad = Monad[State[Set[Int], ?]]

    def add(i: Int): State[Set[Int], Unit] = State.modify(_ + i)

    def remove(i: Int): State[Set[Int], Unit] = State.modify(_ - i)

    def exists(i: Int): State[Set[Int], Boolean] = State.inspect(_(i))

    def lit(i: Int): State[Set[Int], Int] = State.pure(i)
    def plus(l: State[Set[Int], Int], r: State[Set[Int], Int]): State[Set[Int], Int] =
      (l |@| r).map(_ + _)
  }
val result = composedProgram[State[Set[Int], ?]].run(Set.empty[Int]).value
// result: (scala.collection.immutable.Set[Int], Boolean) = (Set(),false)

As before, composedProgram, calcProgram, and setProgram are defined independently of interpretation, so alternative interpretations simply require defining appropriate instances.

A note about laws

Type classes should come with laws - this lets us give meaning to their use. The Monoid type class requires data types to have an associative binary operation and a corresponding identity element. These laws allow us to parallelize batch operations, such as partitioning a List[A] into multiple chunks to be scattered across threads or machines and gathered back.
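A small sketch of why the laws matter here: associativity lets us regroup the elements into chunks, and the identity element gives empty chunks a sensible result. The Monoid trait below is a minimal local definition rather than the Cats one, and ChunkedFold, combineAll, and combineAllChunked are hypothetical names.

```scala
object ChunkedFold {
  // A minimal Monoid, defined locally to keep the sketch self-contained
  trait Monoid[A] {
    def empty: A
    def combine(x: A, y: A): A
  }

  val intAddition: Monoid[Int] = new Monoid[Int] {
    def empty: Int = 0
    def combine(x: Int, y: Int): Int = x + y
  }

  // Plain left-to-right reduction
  def combineAll[A](xs: List[A])(m: Monoid[A]): A =
    xs.foldLeft(m.empty)(m.combine)

  // Split into chunks, reduce each chunk on its own, then reduce the
  // partial results. Each chunk could be handed to another thread or machine.
  def combineAllChunked[A](xs: List[A], chunkSize: Int)(m: Monoid[A]): A = {
    val partials = xs.grouped(chunkSize).map(chunk => combineAll(chunk)(m)).toList
    combineAll(partials)(m)
  }

  val numbers: List[Int] = (1 to 100).toList
  val sequential: Int = combineAll(numbers)(intAddition)
  val chunked: Int = combineAllChunked(numbers, 7)(intAddition)
}
```

Both values agree for any lawful Monoid and any chunk size. For an operation that is not associative, such as subtraction, the chunked result can differ from the sequential one.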

Since our EDSLs are type classes, we should think about what laws we expect to hold. Below are some possible candidates for laws:

// MonadSet
set *> add(i)    *> remove(i) = set
set *> remove(i) *> exists(i) = false
set *> add(i)    *> exists(i) = true

// MonadCalc - these are just the Monoid laws
plus(lit(0), x) = plus(x, lit(0)) = x
plus(x, plus(y, z)) = plus(plus(x, y), z)
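
One way to get a feel for the MonadSet candidates is to check them against a concrete interpreter. The sketch below uses the synchronous shape Set[Int] => (Set[Int], A) rather than Cats. SetLawsCheck, SetOp, and thenRun are hypothetical names, thenRun plays the role of *>, and the starting set stands in for the arbitrary set on the left-hand side of each law.

```scala
object SetLawsCheck {
  type SetOp[A] = Set[Int] => (Set[Int], A)

  def add(i: Int): SetOp[Unit] = s => (s + i, ())
  def remove(i: Int): SetOp[Unit] = s => (s - i, ())
  def exists(i: Int): SetOp[Boolean] = s => (s, s.contains(i))

  // Stand-in for `fa *> fb`: run fa, discard its value, run fb on the updated set
  def thenRun[A, B](fa: SetOp[A], fb: SetOp[B]): SetOp[B] =
    s => fb(fa(s)._1)

  val start: Set[Int] = Set(1, 2, 3)
  val fresh: Int = 42 // deliberately not a member of `start`

  // add(i) *> remove(i) leaves the set unchanged when i was absent
  val addThenRemove: Set[Int] = thenRun(add(fresh), remove(fresh))(start)._1
  // remove(i) *> exists(i) is false
  val removeThenExists: Boolean = thenRun(remove(fresh), exists(fresh))(start)._2
  // add(i) *> exists(i) is true
  val addThenExists: Boolean = thenRun(add(fresh), exists(fresh))(start)._2

  // The same add/remove sequence with an element that was already present
  val addThenRemoveMember: Set[Int] = thenRun(add(2), remove(2))(start)._1
}
```

For this interpreter the second and third candidates hold for any element. The first holds as written only when the element was not already in the set: addThenRemoveMember comes out as Set(1, 3) rather than the starting set. That suggests the first candidate may need a precondition, or a weaker form such as add(i) *> remove(i) = remove(i).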

Next up we’ll take a look at some pitfalls of this approach, and a modified encoding that solves some of them.

This article was tested with Scala 2.11.8, Cats 0.7.2, kind-projector 0.9.0, and si2712fix-plugin 1.2.0 using tut.