Comparing TypedDatasets with Spark's Datasets

Goal: This tutorial compares the standard Spark Datasets API with the one provided by Frameless' TypedDataset. It shows how TypedDatasets allow for an expressive and type-safe API with no compromises on performance.

For this tutorial we first create a simple dataset and save it to disk as a Parquet file. Parquet is a popular columnar format that is well supported by Spark. When operating on Parquet datasets, Spark knows that each column is stored separately, so if a query needs only a subset of the columns, Spark reads just those columns instead of the entire dataset (an optimization known as column pruning). This is a simplified view of how Spark and Parquet work together, but it is sufficient for this discussion.
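To make column pruning concrete, here is a toy, in-memory model of a columnar store in plain Scala. It is an illustrative assumption, not Spark's or Parquet's API: the names ColumnarToy, store and scan are hypothetical, and the data simply mirrors the Foo example used in this tutorial.

```scala
// Toy model of a columnar store (hypothetical; not Spark or Parquet code).
object ColumnarToy {
  // Each column of Foo(i: Long, j: String) is stored as its own vector.
  val store: Map[String, Vector[Any]] = Map(
    "i" -> Vector(1L, 10L, 100L),
    "j" -> Vector("Q", "W", "E")
  )

  // Column pruning: return only the requested columns; the others are never read.
  def scan(required: Set[String]): Map[String, Vector[Any]] =
    store.filter { case (name, _) => required.contains(name) }

  // A query that only needs column i touches a single vector.
  val prunedScan: Map[String, Vector[Any]] = scan(Set("i"))
}
```

A row-oriented layout has no equivalent shortcut, since every stored row carries every field.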

// Assuming spark is loaded and a SparkSession is bound to spark
import spark.implicits._

// Our example case class Foo, acting here as a schema
case class Foo(i: Long, j: String)

val initialDs = spark.createDataset( Foo(1, "Q") :: Foo(10, "W") :: Foo(100, "E") :: Nil )
// initialDs: org.apache.spark.sql.Dataset[Foo] = [i: bigint, j: string]

// Assuming you are on Linux or Mac OS
initialDs.write.parquet("/tmp/foo")

val ds = spark.read.parquet("/tmp/foo").as[Foo]
// ds: org.apache.spark.sql.Dataset[Foo] = [i: bigint, j: string]

ds.show()
// +---+---+
// |  i|  j|
// +---+---+
// |  1|  Q|
// |100|  E|
// | 10|  W|
// +---+---+
//

The value ds holds the contents of initialDs, read back from the Parquet file. Let's use only field i from Foo and see how Catalyst, Spark's query optimizer, handles this.

// Using a standard Spark TypedColumn in select()
val filteredDs = ds.filter($"i" === 10).select($"i".as[Long])
// filteredDs: org.apache.spark.sql.Dataset[Long] = [i: bigint]

filteredDs.show()
// +---+
// |  i|
// +---+
// | 10|
// +---+
//

The filteredDs is of type Dataset[Long]. Since we only access field i of Foo, the type is correct. Unfortunately, this syntax requires hand-holding: we must explicitly declare the TypedColumn's return type as Long in the select statement (note the as[Long]). We discuss this limitation in more detail below. First, let's take a quick look at the optimized physical plan that Catalyst generated.

filteredDs.explain()
// == Physical Plan ==
// *(1) Filter (isnotnull(i#2419L) AND (i#2419L = 10))
// +- *(1) ColumnarToRow
//    +- FileScan parquet [i#2419L] Batched: true, DataFilters: [isnotnull(i#2419L), (i#2419L = 10)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/tmp/foo], PartitionFilters: [], PushedFilters: [IsNotNull(i), EqualTo(i,10)], ReadSchema: struct<i:bigint>
// 
//

The last line is the important one (see ReadSchema): Spark read only column i from the Parquet file and never accessed column j. The PushedFilters entry also shows that the filter on i was pushed down to the Parquet reader. This is great! We have both an optimized query plan and type safety.

Unfortunately, this syntax is not bulletproof: it fails at runtime if we try to access a non-existent column x:

ds.filter($"i" === 10).select($"x".as[Long])
// org.apache.spark.sql.AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `x` cannot be resolved. Did you mean one of the following? [`i`, `j`].;
// 'Project ['x]
// +- Filter (i#2419L = cast(10 as bigint))
//    +- Relation [i#2419L,j#2420] parquet
// 

There are two things to improve here. First, we want to avoid the as[Long] cast that we must write by hand to get a typed result; it is an easy place to introduce a bug by casting to an incompatible type. Second, we want a reference to a non-existent column name to fail at compile time. The standard Spark Dataset can achieve both using the following syntax:

ds.filter(_.i == 10).map(_.i).show()
// +-----+
// |value|
// +-----+
// |   10|
// +-----+
//

This looks great! It resembles the familiar syntax of Scala collections. The closures passed to filter and map are functions that operate on Foo, so the compiler helps us catch all the mistakes mentioned above:

ds.filter(_.i == 10).map(_.x).show()
// error: value x is not a member of repl.MdocSession.MdocApp0.Foo
// ds.filter(_.i == 10).map(_.x).show()
//                          ^^^

Unfortunately, this syntax does not allow Spark to optimize the code.

ds.filter(_.i == 10).map(_.i).explain()
// == Physical Plan ==
// *(1) SerializeFromObject [input[0, bigint, false] AS value#2471L]
// +- *(1) MapElements <function1>, obj#2470: bigint
//    +- *(1) Filter <function1>.apply
//       +- *(1) DeserializeToObject newInstance(class repl.MdocSession$MdocApp0$Foo), obj#2469: repl.MdocSession$MdocApp0$Foo
//          +- *(1) ColumnarToRow
//             +- FileScan parquet [i#2419L,j#2420] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/tmp/foo], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<i:bigint,j:string>
// 
//

As the physical plan shows, Spark was not able to optimize this query as before. Reading the Parquet file requires loading all the fields of Foo (ReadSchema: struct<i:bigint,j:string>), and since PushedFilters is empty, every row is deserialized into a Foo object before the filter runs. This might be OK for small datasets or for datasets with few columns, but it will be much slower for most practical applications. Intuitively, Spark has no way to look inside the closures we pass to filter and map. It only knows that each takes one argument of type Foo; it cannot tell whether we use one of Foo's fields or all of them.
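The difference can be sketched with a toy model in plain Scala. The classes below (Expr, Col, Lit, EqualTo) are hypothetical stand-ins, not Catalyst's real expression classes: a column expression is a data structure that a planner can walk to find the columns it references, while a closure exposes nothing but its type.

```scala
// Toy model (hypothetical classes, not Catalyst) of why $"i" === 10 can be
// pruned and pushed down, while (foo: Foo) => foo.i == 10 cannot.
object ClosureVsExpression {
  final case class Foo(i: Long, j: String)

  // A tiny expression language: a tree that a planner can inspect.
  sealed trait Expr
  final case class Col(name: String) extends Expr
  final case class Lit(value: Any) extends Expr
  final case class EqualTo(left: Expr, right: Expr) extends Expr

  // Walk the tree and collect every column it references.
  def references(e: Expr): Set[String] = e match {
    case Col(name)     => Set(name)
    case Lit(_)        => Set.empty
    case EqualTo(l, r) => references(l) ++ references(r)
  }

  val allColumns: Set[String] = Set("i", "j")

  // With an expression tree, the columns to read are known before execution.
  val exprFilter: Expr = EqualTo(Col("i"), Lit(10L))
  val columnsForExpr: Set[String] = references(exprFilter)

  // A closure is opaque: all a planner sees is Foo => Boolean, so it must
  // assume every field may be used and materialize whole Foo objects.
  val closureFilter: Foo => Boolean = _.i == 10
  val columnsForClosure: Set[String] = allColumns
}
```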

The TypedDataset in Frameless solves this problem. It allows for a simple and type-safe syntax with a fully optimized query plan.

import frameless.TypedDataset
import frameless.syntax._
val fds = TypedDataset.create(ds)
// fds: TypedDataset[Foo] = [i: bigint, j: string]

fds.filter(fds('i) === 10).select(fds('i)).show().run()
// +-----+
// |value|
// +-----+
// |   10|
// +-----+
//

And the optimized physical plan, which again reads only column i and pushes the filter down to the Parquet reader:

fds.filter(fds('i) === 10).select(fds('i)).explain()
// == Physical Plan ==
// *(1) Project [i#2419L AS value#2555L]
// +- *(1) Filter (isnotnull(i#2419L) AND (i#2419L = 10))
//    +- *(1) ColumnarToRow
//       +- FileScan parquet [i#2419L] Batched: true, DataFilters: [isnotnull(i#2419L), (i#2419L = 10)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/tmp/foo], PartitionFilters: [], PushedFilters: [IsNotNull(i), EqualTo(i,10)], ReadSchema: struct<i:bigint>
// 
//

And the compiler is our friend: a reference to a non-existent column is rejected at compile time.

fds.filter(fds('i) === 10).select(fds('x))
// error: No column Symbol with shapeless.tag.Tagged[String("x")] of type A in repl.MdocSession.MdocApp0.Foo
// fds.filter(fds('i) === 10).select(fds('x))
//                                      ^
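This kind of compile-time check can be approximated in plain Scala 2.13 with implicit evidence keyed on literal types. The sketch is simplified and hypothetical, not Frameless' implementation: the names Exists, TypedCol and columnsOf are illustrative, and the evidence instances are written by hand, whereas the error message above suggests that Frameless finds such evidence through an implicit search over the case class's fields.

```scala
// Toy sketch (Scala 2.13+, literal types) of compile-time-checked columns.
object CheckedColumns {
  final case class Foo(i: Long, j: String)

  // Evidence that record type T has a column named K holding values of type A.
  final class Exists[T, K <: String, A]

  object Exists {
    // Hand-written here; a library would derive these from Foo's fields.
    implicit val fooI: Exists[Foo, "i", Long]   = new Exists[Foo, "i", Long]
    implicit val fooJ: Exists[Foo, "j", String] = new Exists[Foo, "j", String]
  }

  // A column reference that carries both the record type and the value type.
  final case class TypedCol[T, A](name: String)

  final class ColumnsOf[T] {
    // K is inferred as the literal type of the name, e.g. "i";
    // A is fixed by whichever Exists instance matches.
    def apply[K <: String with Singleton, A](name: K)(implicit ev: Exists[T, K, A]): TypedCol[T, A] =
      TypedCol[T, A](name)
  }

  def columnsOf[T]: ColumnsOf[T] = new ColumnsOf[T]

  // Compiles: the value type Long comes from the evidence, no manual cast needed.
  val iCol: TypedCol[Foo, Long] = columnsOf[Foo]("i")

  // Does not compile: there is no Exists[Foo, "x", A] instance.
  // val xCol = columnsOf[Foo]("x")
}
```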

Differences in Encoders

Encoders in Spark's Datasets are only partially type-safe. If you try to create a Dataset of a type that is neither a supported primitive (Int, String, etc.) nor a Scala Product, you get a compilation error:

class Bar(i: Int)

Bar is neither a case class nor a Product, so the following correctly gives a compilation error in Spark:

spark.createDataset(Seq(new Bar(1)))
// error: Unable to find encoder for type repl.MdocSession.MdocApp0.Bar. An implicit Encoder[repl.MdocSession.MdocApp0.Bar] is needed to store repl.MdocSession.MdocApp0.Bar instances in a Dataset. Primitive types (Int, String, etc) and Product types (case classes) are supported by importing spark.implicits._  Support for serializing other types will be added in future releases.
// spark.createDataset(Seq(new Bar(1)))
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

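However, the compile-time guards implemented in Spark are not sufficient to detect members that cannot be encoded. For example, the following case class is a Product, so Spark finds an Encoder for it at compile time, but its java.util.Calendar field cannot be encoded, which leads to a runtime failure: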

case class MyDate(jday: java.util.Calendar)
spark.createDataset(Seq(MyDate {
  val cal = new java.util.GregorianCalendar()
  cal.setTime(new java.util.Date(System.currentTimeMillis))
  cal
}))
// org.apache.spark.SparkUnsupportedOperationException: [ENCODER_NOT_FOUND] Not found an encoder of the type java.util.Calendar to Spark SQL internal representation. Consider to change the input type to one of supported at 'https://spark.apache.org/docs/latest/sql-ref-datatypes.html'.

In contrast, TypedDataset reports the encoding problem at compile time:

TypedDataset.create(Seq(MyDate {
  val cal = new java.util.GregorianCalendar()
  cal.setTime(new java.util.Date(System.currentTimeMillis))
  cal
}))
// error: could not find implicit value for parameter encoder: frameless.TypedEncoder[repl.MdocSession.MdocApp0.MyDate]
// TypedDataset.create(Seq(MyDate {
// ^
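If MyDate really needs to be stored in a TypedDataset, one option is to tell Frameless how to encode java.util.Calendar through an Injection into a type it already supports. The sketch below assumes frameless.Injection and its Injection(f, g) constructor as described in the Frameless documentation, plus a local implicit SparkSession; the object name is illustrative. Encoding the Calendar as epoch milliseconds (a Long) is a choice made for this example only, and the round trip does not preserve a non-default time zone.

```scala
import frameless.{Injection, TypedDataset}
import java.util.{Calendar, GregorianCalendar}
import org.apache.spark.sql.SparkSession

object CalendarInjectionSketch {
  // Local session for the sketch; TypedDataset.create(Seq(...)) needs it implicitly.
  implicit val spark: SparkSession =
    SparkSession.builder().master("local[*]").appName("injection-sketch").getOrCreate()

  final case class MyDate(jday: Calendar)

  // Assumed representation: epoch milliseconds, which Frameless already encodes as Long.
  implicit val calendarToMillis: Injection[Calendar, Long] = Injection(
    (cal: Calendar) => cal.getTimeInMillis,
    (millis: Long) => {
      val cal = new GregorianCalendar()
      cal.setTimeInMillis(millis)
      cal
    }
  )

  val now: Calendar = new GregorianCalendar()

  // With the Injection in scope, a TypedEncoder[MyDate] can now be derived.
  val dates: TypedDataset[MyDate] = TypedDataset.create(Seq(MyDate(now)))
}
```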

Aggregate vs Projected columns

Spark's Datasets do not distinguish between columns created by aggregate operations, such as summing or averaging, and simple projections/selections. This becomes a problem when you mix the two, and the mistake is only detected at runtime:

import org.apache.spark.sql.functions.sum
ds.select(sum($"i"), $"i"*2)
// org.apache.spark.sql.AnalysisException: [MISSING_GROUP_BY] The query does not include a GROUP BY clause. Add GROUP BY or turn it into the window functions using OVER clauses.;
// Aggregate [sum(i#2419L) AS sum(i)#2561L, (i#2419L * cast(2 as bigint)) AS (i * 2)#2562L]
// +- Relation [i#2419L,j#2420] parquet
// 
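For comparison, Spark accepts both kinds of columns when they are used consistently: select with aggregates only, or non-aggregated columns used as grouping keys with groupBy. The sketch assumes a local SparkSession and re-creates the Foo data inside a hypothetical SparkGroupedSketch object so it is self-contained. Note that the same mistakes are still only caught at runtime with this API.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.functions.sum

object SparkGroupedSketch {
  val spark: SparkSession =
    SparkSession.builder().master("local[*]").appName("grouped-sketch").getOrCreate()
  import spark.implicits._

  final case class Foo(i: Long, j: String)
  val ds: Dataset[Foo] = spark.createDataset(Seq(Foo(1, "Q"), Foo(10, "W"), Foo(100, "E")))

  // Aggregates only: a single row holding the total of column i.
  val total: Long = ds.select(sum($"i")).as[Long].head()

  // A non-aggregated column is allowed once it is a grouping key.
  // Tuple columns are matched by position when converting with as[(String, Long)].
  val perKey: Map[String, Long] =
    ds.groupBy($"j").agg(sum($"i")).as[(String, Long)].collect().toMap
}
```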

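In Frameless, aggregate and projected columns have different types, TypedAggregate and TypedColumn, so passing an aggregate where a projection is expected results in a compilation error: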

// To avoid confusing frameless' sum with the standard Spark's sum
import frameless.functions.aggregate.{sum => fsum}
fds.select(fsum(fds('i)))
// error: polymorphic expression cannot be instantiated to expected type;
//  found   : [Out]frameless.TypedAggregate[repl.MdocSession.MdocApp0.Foo,Out]
//  required: frameless.TypedColumn[repl.MdocSession.MdocApp0.Foo,?]
// fds.select(fsum(fds('i)))
//            ^^^^^^^^^^^^^

As the error shows, select expected a TypedColumn but received a TypedAggregate instead.
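The idea behind this error can be modeled in plain Scala: give per-row expressions and whole-dataset aggregates different types, and let select and agg each accept only one kind. Projection and Aggregate below are hypothetical, simplified stand-ins for TypedColumn and TypedAggregate, and they operate on an in-memory Seq rather than on Spark.

```scala
// Toy sketch (hypothetical classes, not Frameless) of separating projections
// from aggregates in the type system.
object ProjectionVsAggregate {
  final case class Foo(i: Long, j: String)

  // A per-row expression: one output value for each input row.
  final case class Projection[T, A](run: T => A)
  // An aggregate: one output value for the whole input.
  final case class Aggregate[T, A](run: Seq[T] => A)

  // select only accepts projections...
  def select[T, A](rows: Seq[T], p: Projection[T, A]): Seq[A] = rows.map(p.run)
  // ...and agg only accepts aggregates.
  def agg[T, A](rows: Seq[T], a: Aggregate[T, A]): A = a.run(rows)

  val rows: Seq[Foo] = Seq(Foo(1, "Q"), Foo(10, "W"), Foo(100, "E"))

  val doubled: Seq[Long] = select(rows, Projection[Foo, Long](_.i * 2))
  val total: Long = agg(rows, Aggregate[Foo, Long](_.map(_.i).sum))

  // Does not compile: an Aggregate is not a Projection.
  // select(rows, Aggregate[Foo, Long](_.map(_.i).sum))
}
```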

Here is how you apply an aggregation method in Frameless:

fds.agg(fsum(fds('i))+22).show().run()
// +-----+
// |value|
// +-----+
// |  133|
// +-----+
//

Similarly, mixing projections into an aggregation does not make sense, and in Frameless it results in a compilation error:

fds.agg(fsum(fds('i)), fds('i)).show().run()
// error: polymorphic expression cannot be instantiated to expected type;
//  found   : [A]frameless.TypedColumn[repl.MdocSession.MdocApp0.Foo,A]
//  required: frameless.TypedAggregate[repl.MdocSession.MdocApp0.Foo,?]
// fds.agg(fsum(fds('i)), fds('i)).show().run()
//                        ^^^^^^^
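When a projected column is really meant as a grouping key, the grouped form of agg combines the two. This sketch assumes Frameless' groupBy(...).agg(...) operations and a local implicit SparkSession; the object name FramelessGroupedSketch is illustrative. The result is a TypedDataset of (key, aggregate) pairs whose types are inferred from the columns.

```scala
import frameless.TypedDataset
import frameless.functions.aggregate.{sum => fsum}
import org.apache.spark.sql.SparkSession

object FramelessGroupedSketch {
  implicit val spark: SparkSession =
    SparkSession.builder().master("local[*]").appName("frameless-grouped-sketch").getOrCreate()

  final case class Foo(i: Long, j: String)
  val fds: TypedDataset[Foo] =
    TypedDataset.create(Seq(Foo(1, "Q"), Foo(10, "W"), Foo(100, "E")))

  // Column j becomes a grouping key, so combining it with an aggregate type-checks.
  val perKey = fds.groupBy(fds('j)).agg(fsum(fds('i)))
}
```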