Primitive recursion with fix and Mu

by Stephen Compall on Apr 14, 2014

Consider the simple cons-list datatype.

import scalaz.Equal, scalaz.std.option._, scalaz.syntax.std.option._,
       scalaz.std.anyVal._, scalaz.std.function._,
       scala.reflect.runtime.universe.reify

sealed abstract class XList[A]
final case class XNil[A]() extends XList[A]
final case class XCons[A](head: A, tail: XList[A]) extends XList[A]

And a simple function over this structure. Say, a simple summing function.

def sum(xs: XList[Int]): Int = xs match {
  case XNil() => 0
  case XCons(x, xs) => x + sum(xs)
}

And that seems to work out alright.

scala> val nums = XCons(2, XCons(3, XCons(4, XCons(42, XNil()))))
nums: XCons[Int] = XCons(2,XCons(3,XCons(4,XCons(42,XNil()))))

scala> sum(nums)
res0: Int = 51

Has it ever struck you as curious that the system has no problem with a definition like sum, even though sum’s own value is required to construct it?

Oh, well, that’s just a recursive function, you say. Well, what’s so special about recursive functions? Why do they get special treatment so that they can define themselves with themselves?

Induction and termination

First, let’s be clear: there’s a limit to how much of sum can be used in its own definition.

Let us consider the moral equivalent of the statement “this function gives the sum of a list of integers because it is the function that gives the sum of a list of integers.”

def sum2(xs: XList[Int]): Int = sum2(xs)

scalac will compile this definition; it is well-typed. However, it is nonsense, and it behaves like nonsense at runtime: it will either overflow the stack or loop forever, depending on whether the compiler eliminates the tail call.

Let us consider a similar case: the infinite list of 42s.

scala> val fortyTwos: Stream[Int] = 42 #:: fortyTwos
fortyTwos: Stream[Int] = Stream(42, ?)

scala> fortyTwos take 5 toList
res0: List[Int] = List(42, 42, 42, 42, 42)

The definition of fortyTwos is like that of sum; it uses its own value while constructing said value. The analogue of sum2 is, likewise, nonsense, though scalac can catch this particular case:

scala> val fortyTwos2: Stream[Int] = fortyTwos2
<console>:7: error: value fortyTwos2 does nothing other than call itself recursively
       val fortyTwos2: Stream[Int] = fortyTwos2
                                     ^

Obviously, functions aren’t special; non-function values, like functions, can be defined using their own values. But how can we characterize the difference between the good, terminating definitions, and the bad, nonterminating definitions?

Proof systems like Coq and Agda perform a strong check on recursive definitions; for definitions like sum, they require the recursion match the structure of the data type, just as ours does, so that each recursive call is known to operate over smaller data. For definitions like fortyTwos, they apply other strategies. In Scala, we have to make do with informality.

I like to think of it this way: a recursive definition must always perform at least one inductive step. sum does so because, in the recursive case, it gives “supposing I have the sum of tail, the sum is the head plus that.” fortyTwos does because it says “the value fortyTwos is 42 consed onto the value fortyTwos.” It is, at least, the start of a systematic way of thinking about terminating recursive definitions.

Abstracting the recursion

Now that we have a framework for thinking about what is required in a recursive definition, we can start abstracting over it.

The above recursive definitions were accomplished with special language support: the right-hand side of any term definition, val or def, can refer to the thing being so defined. Scalaz provides the fix function, which, if it were provided intrinsically, would eliminate the need for this language support.

def fix[A](f: (=> A) => A): A = {
  lazy val a: A = f(a)
  a
}

In this definition, the value returned by f is the very value passed to it as an argument. The argument is by-name because that is how we enforce the requirement from the previous section: f may refer to its result through the argument, but it must return a value before evaluating that argument, and so it must perform at least one inductive step in defining its result.

Let’s redefine sum with fix, after importing it from scalaz.std.function.

val sum3: XList[Int] => Int = fix[XList[Int] => Int](rec => {
  case XNil() => 0
  case XCons(x, xs) => x + rec.apply(xs)
})

And Scala thinks that’s alright.

scala> sum3(nums)
res1: Int = 51

The interesting thing here is that sum3’s definition doesn’t refer to the name sum3; the recursion is entirely inside the fix argument. So one advantage of fix is that it’s easy to write recursive values as expressions without giving them a name.

For example, there’s the definition of fortyTwos:

scala> fix[Stream[Int]](42 #:: _)
res2: Stream[Int] = Stream(42, ?)

scala> res2 take 5 toList
res3: List[Int] = List(42, 42, 42, 42, 42)
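
The same holds for functions. As a minimal sketch, here is a recursive factorial that is passed straight to map without ever being bound to a name. The object name AnonymousRecursion and the value factorials are made up for illustration, and fix is copied from above so that the block compiles without Scalaz on the classpath.

```scala
object AnonymousRecursion {
  // Same definition as in the article, repeated so this block stands alone.
  def fix[A](f: (=> A) => A): A = {
    lazy val a: A = f(a)
    a
  }

  // A recursive factorial written inline; the recursion goes through `rec`,
  // which is only evaluated once the function is actually applied.
  val factorials: List[Int] =
    List(1, 2, 3, 4, 5).map(
      fix[Int => Int](rec => n => if (n <= 1) 1 else n * rec.apply(n - 1)))
}
```

Here the function literal n => ... is the inductive step: it is returned before rec is touched, so the lazy val inside fix is fully initialized by the time any recursive call happens.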

For special data structures

It can be inconvenient to avoid evaluating the argument when providing an induction step. Fortunately, requiring f to be nonstrict in its argument is stronger than necessary; more values can be defined with fix-style recursion than that requirement suggests.

For a given data type, there’s often a way to abstract out the nonstrictness. For example, here’s an Equal instance combinator that is fully evaluated, but doesn’t force the argument until after the (equivalent) result has been produced.

def lazyEqual[A](A: => Equal[A]): Equal[A] = new Equal[A] {
  def equal(l: A, r: A): Boolean = A equal (l, r)
  override def equalIsNatural = A.equalIsNatural
}

Given that, we can produce a fix variant for Equal that passes the Equal argument strictly. The only rule is that f must not invoke any of the typeclass’s methods on that argument before returning its result.

def fixEq[A](f: Equal[A] => Equal[A]): Equal[A] =
  fix[Equal[A]](A => f(lazyEqual(A)))
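
To see the pattern in isolation, here is a rough sketch that needs no Scalaz at all. SimpleEq is a hypothetical one-method stand-in for scalaz.Equal, and lazySimpleEq and fixSimpleEq are illustrative names mirroring lazyEqual and fixEq; none of them come from the library.

```scala
object LazyEqSketch {
  // Hypothetical minimal stand-in for scalaz.Equal.
  trait SimpleEq[A] { def equal(l: A, r: A): Boolean }

  // Same definition as in the article, repeated so this block stands alone.
  def fix[A](f: (=> A) => A): A = {
    lazy val a: A = f(a)
    a
  }

  // A real, fully constructed instance that forces `underlying`
  // only when `equal` is actually called.
  def lazySimpleEq[A](underlying: => SimpleEq[A]): SimpleEq[A] =
    new SimpleEq[A] {
      def equal(l: A, r: A): Boolean = underlying.equal(l, r)
    }

  // `f` receives its argument strictly, because the laziness
  // lives inside the wrapper.
  def fixSimpleEq[A](f: SimpleEq[A] => SimpleEq[A]): SimpleEq[A] =
    fix[SimpleEq[A]](self => f(lazySimpleEq(self)))

  // Equality on List[Int]: `rec` is only used inside `equal`,
  // never while the instance is being built.
  val intListEq: SimpleEq[List[Int]] =
    fixSimpleEq[List[Int]](rec => new SimpleEq[List[Int]] {
      def equal(l: List[Int], r: List[Int]): Boolean = (l, r) match {
        case (Nil, Nil)         => true
        case (x :: xs, y :: ys) => x == y && rec.equal(xs, ys)
        case _                  => false
      }
    })
}
```

Calling rec.equal while constructing the instance, rather than inside its equal method, would force the lazy val in fix during its own initialization, which is exactly the situation the rule above forbids.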

And now, we have the machinery to build a fully derived Equal instance for XList, without function recursion, by defining the base case and inductive step!

def `list equal`[A: Equal]: Equal[XList[A]] =
  fixEq[XList[A]](implicit rec =>
    Equal.equalBy[XList[A], Option[(A, XList[A])]]{
      case XNil() => None
      case XCons(x, xs) => Some((x, xs))
    })

That works out to interesting compiled output. Note especially the last line, and its (strict) use of rec towards the end.

scala> reify(fixEq[XList[Int]](implicit rec =>
     |     Equal.equalBy[XList[Int], Option[(Int, XList[Int])]]{
     |       case XNil() => None
     |       case XCons(x, xs) => Some((x, xs))
     |     }))
res10: reflect.runtime.universe.Expr[scalaz.Equal[XList[Int]]] = 
Expr[scalaz.Equal[XList[Int]]]($read.fixEq[$read.XList[Int]](((implicit rec) =>
 Equal.equalBy[$read.XList[Int], Option[Tuple2[Int, $read.XList[Int]]]]
 (((x0$1) => x0$1 match {
  case $read.XNil() => None
  case $read.XCons((x @ _), (xs @ _)) => Some.apply(Tuple2.apply(x, xs))
}))(option.optionEqual(tuple.tuple2Equal(anyVal.intInstance, rec))))))

f0, a binary serialization library, uses a similar technique to help define codecs on recursive data structures.

What about XList?

If we can abstract out the idea of recursive value definitions, what about recursive type definitions? Well, thanks to higher kinds, sure! Scalaz doesn’t provide it, but it is commonly called Mu.

final case class Mu[F[_]](value: F[Mu[F]])

We have to put a class in the middle of it so that we don’t have an infinite type; Haskell has a similar restriction. But the principle is the same as with fix: feed one datatype induction step F to the higher-order type Mu and it will feed F’s result back to itself.
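
As a smaller warm-up, and only as a sketch, feeding Option itself to Mu gives something shaped like the natural numbers: None plays the role of zero and Some wraps the predecessor. The names MuNatSketch, Nat, zero, succ and toInt are invented for this illustration.

```scala
object MuNatSketch {
  // Same definition as in the article, repeated so this block stands alone.
  final case class Mu[F[_]](value: F[Mu[F]])

  // One induction step is "nothing, or one more layer".
  type Nat = Mu[Option]

  val zero: Nat = Mu[Option](None)
  def succ(n: Nat): Nat = Mu[Option](Some(n))

  // Count the layers of Some by ordinary structural recursion.
  def toInt(n: Nat): Int = n.value match {
    case None       => 0
    case Some(pred) => 1 + toInt(pred)
  }
}
```

Older compilers may emit a feature warning about higher-kinded types for Mu; that is only a warning and does not affect the result.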

For example, here is the equivalent definition of XList with Mu.

type XList2Step[A] = {type λ[α] = Option[(A, α)]}
type XList2[A] = Mu[XList2Step[A]#λ]

Note the type lambda’s similarity to the second type argument to Equal.equalBy above. And for demonstration, here is the isomorphism with XList.

def onetotwo[A](xs: XList[A]): XList2[A] = xs match {
  case XNil() => Mu[XList2Step[A]#λ](None)
  case XCons(x, xs) => Mu[XList2Step[A]#λ](Some((x, onetotwo(xs))))
}

def twotoone[A](xs: XList2[A]): XList[A] =
  xs.value cata ({case (x, xs) => XCons(x, twotoone(xs))}, XNil())

Of course, fix lends itself to both of these definitions; I have left its use off here. But let’s check those functions:

scala> onetotwo(nums)
res11: XList2[Int] = Mu(Some((2,Mu(Some((3,Mu(Some((4,Mu(Some((42,Mu(None)))))))))))))

scala> twotoone(res11)
res12: XList[Int] = XCons(2,XCons(3,XCons(4,XCons(42,XNil()))))

fix over Mu

And, finally, the associated general Equal definition for Mu. The contramap step is just noise to deal with the fact that the Mu structure has to actually exist; you can ignore it for the most part.

def equalMu[F[_]](fa: Equal[Mu[F]] => Equal[F[Mu[F]]]): Equal[Mu[F]] =
  fixEq[Mu[F]](emf => fa(emf) contramap (_.value))

The evidence we really want is forall a. Equal[a] => Equal[F[a]], but that’s too hard to express in Scala, so this does it in a pinch. All we’re interested in is that we can derive F’s equality given the equality of any type argument given to it. Let’s prove that we have such an Equal-lifter:

// redefined because Scalaz's Tuple2Equal is strict on equalIsNatural
class Tuple2Equal[A1, A2](_1: Equal[A1], _2: Equal[A2])
    extends Equal[(A1, A2)] {
  def equal(f1: (A1, A2), f2: (A1, A2)) =
    _1.equal(f1._1, f2._1) && _2.equal(f1._2, f2._2)
  override def equalIsNatural: Boolean = _1.equalIsNatural && _2.equalIsNatural
}
implicit def tup2eq[A1: Equal, A2: Equal] =
  new Tuple2Equal[A1, A2](implicitly, implicitly)

abstract class Blah // just a placeholder

scala> {implicit X: Equal[Blah] => implicitly[Equal[XList2Step[Int]#λ[Blah]]]}
res4: scalaz.Equal[Blah] => scalaz.Equal[Option[(Int, Blah)]] = <function1>

And now that we have F equality, we’re done, because Mu is Fs all the way down.

scala> equalMu[XList2Step[Int]#λ](implicit fa => implicitly)
res5: scalaz.Equal[Mu[[α]Option[(Int, α)]]] = scalaz.Equal$$anon$2@de52bcf

scala> res5 equal (onetotwo(nums), onetotwo(nums))
res6: Boolean = true

scala> res5 equal (onetotwo(nums), onetotwo(XCons(3,nums)))
res7: Boolean = false
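
For the curious, the evidence forall a. Equal[a] => Equal[F[a]] can be approximated, at the cost of some ceremony, by a trait with a method that is polymorphic in the element type. What follows is one possible encoding under that assumption, not what equalMu above does. SimpleEq, EqLift, eqMu, IntListStep and stepLift are all hypothetical names, and the sketch deliberately avoids Scalaz so that it stands alone.

```scala
object EqLiftSketch {
  // Hypothetical minimal stand-in for scalaz.Equal.
  trait SimpleEq[A] { def equal(l: A, r: A): Boolean }

  // Same definition as in the article, repeated so this block stands alone.
  final case class Mu[F[_]](value: F[Mu[F]])

  // Hypothetical encoding of "forall a. Equal[a] => Equal[F[a]]":
  // the quantifier becomes the type parameter of `lift`.
  trait EqLift[F[_]] {
    def lift[A](ea: SimpleEq[A]): SimpleEq[F[A]]
  }

  // Equality for Mu[F]. The instance hands itself to `lift`, lazily,
  // so it is fully constructed before it is ever used recursively.
  def eqMu[F[_]](lifter: EqLift[F]): SimpleEq[Mu[F]] =
    new SimpleEq[Mu[F]] { self =>
      private lazy val unrolled: SimpleEq[F[Mu[F]]] = lifter.lift(self)
      def equal(l: Mu[F], r: Mu[F]): Boolean =
        unrolled.equal(l.value, r.value)
    }

  // The list induction step from the article, fixed to Int elements.
  type IntListStep[X] = Option[(Int, X)]

  val stepLift: EqLift[IntListStep] = new EqLift[IntListStep] {
    def lift[A](ea: SimpleEq[A]): SimpleEq[IntListStep[A]] =
      new SimpleEq[IntListStep[A]] {
        def equal(l: IntListStep[A], r: IntListStep[A]): Boolean =
          (l, r) match {
            case (None, None)                   => true
            case (Some((x, xs)), Some((y, ys))) => x == y && ea.equal(xs, ys)
            case _                              => false
          }
      }
  }

  val intListEq: SimpleEq[Mu[IntListStep]] = eqMu[IntListStep](stepLift)
}
```

The trade-off is visible in stepLift: the lifter has to be written out by hand as an object, whereas the Equal[Mu[F]] => Equal[F[Mu[F]]] function used by equalMu can be filled in by implicit resolution.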

This article was tested with Scala 2.10.4 & Scalaz 7.0.6.