Comparing TypedDatasets with Spark's Datasets
Goal:
This tutorial compares the standard Spark Datasets API with the one provided by Frameless' TypedDataset. It shows how TypedDatasets allow for an expressive and type-safe API with no compromises on performance.
For this tutorial we first create a simple dataset and save it on disk as a parquet file. Parquet is a popular columnar format that Spark supports well. Because Parquet stores each column separately, Spark knows that a query needing only a subset of the columns can read just those columns instead of the entire dataset. This is a simplified view of how Spark and parquet work together, but it will serve us well for this discussion.
// Assuming a SparkSession is already created and bound to spark
import spark.implicits._
// Our example case class Foo, acting here as a schema
case class Foo(i: Long, j: String)
val initialDs = spark.createDataset( Foo(1, "Q") :: Foo(10, "W") :: Foo(100, "E") :: Nil )
// initialDs: org.apache.spark.sql.Dataset[Foo] = [i: bigint, j: string]
// Assuming you are on Linux or Mac OS; "overwrite" lets the example be re-run
initialDs.write.mode("overwrite").parquet("/tmp/foo")
val ds = spark.read.parquet("/tmp/foo").as[Foo]
// ds: org.apache.spark.sql.Dataset[Foo] = [i: bigint, j: string]
ds.show()
// +---+---+
// | i| j|
// +---+---+
// |100| E|
// | 1| Q|
// | 10| W|
// +---+---+
//
The value ds holds the content of initialDs, read back from the parquet file.
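To make the column-pruning idea from the introduction concrete, here is a minimal plain-Scala sketch (no Spark) of a toy columnar store. It assumes each column is kept as its own vector keyed by name. The names ToyColumnar, ColumnarFile and scan are hypothetical. Real Parquet files are organized into row groups and column chunks, and this sketch ignores that. It only shows why a query that needs one column can skip the others.

```scala
// Toy model of a columnar file (hypothetical, not Parquet's real layout):
// every column is stored separately, keyed by its name.
object ToyColumnar {
  type ColumnarFile = Map[String, Vector[Any]]

  // The same three Foo rows as above, stored column by column.
  val fooFile: ColumnarFile = Map(
    "i" -> Vector(1L, 10L, 100L),
    "j" -> Vector("Q", "W", "E")
  )

  // Loads only the requested columns; the other columns are never touched.
  def scan(file: ColumnarFile, required: Set[String]): ColumnarFile =
    file.filter { case (name, _) => required.contains(name) }

  // "select i where i = 10" needs column i only, so column j is skipped.
  val pruned: ColumnarFile = scan(fooFile, Set("i"))
  val result: Vector[Any] = pruned("i").filter(_ == 10L)
}
```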
Let's use only the field i from Foo and see how Spark's Catalyst (the query optimizer) optimizes this.
// Using a standard Spark TypedColumn in select()
val filteredDs = ds.filter($"i" === 10).select($"i".as[Long])
// filteredDs: org.apache.spark.sql.Dataset[Long] = [i: bigint]
filteredDs.show()
// +---+
// | i|
// +---+
// | 10|
// +---+
//
The value filteredDs is of type Dataset[Long]. Since we only access the field i of Foo, which is a Long, the type is correct.
Unfortunately, this syntax requires some handholding: we must explicitly turn the column in the select statement into a TypedColumn that returns Long (note the as[Long] call). We discuss this limitation in more detail below.
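The root of the problem is that $"i" is an untyped column. It carries only a name, so as[Long] is a claim the compiler cannot verify. The plain-Scala sketch below models this with hypothetical UntypedCol, TypedCol and resolve names; these are not Spark's actual classes. In this model, any column name and any target type compile. Mistakes surface only when the column is checked against the schema at run time, which is roughly the role Spark's analyzer plays.

```scala
import scala.reflect.ClassTag

// Hypothetical model of untyped columns (not Spark's real classes).
object UntypedColumns {
  // Foo's schema as known at run time: column name -> runtime class.
  val schema: Map[String, Class[_]] = Map("i" -> classOf[Long], "j" -> classOf[String])

  // A column that claims to hold values of type V; the claim is unchecked.
  final case class TypedCol[V](name: String, claimed: Class[_])

  // Like $"i": only a name, no type information.
  final case class UntypedCol(name: String) {
    // Any V compiles here; nothing ties it to the schema.
    def as[V](implicit ct: ClassTag[V]): TypedCol[V] = TypedCol[V](name, ct.runtimeClass)
  }

  // Run-time "analysis": the first point where name and type are checked.
  def resolve[V](col: TypedCol[V]): Either[String, TypedCol[V]] =
    schema.get(col.name) match {
      case None => Left(s"cannot resolve column ${col.name}")
      case Some(actual) if actual != col.claimed =>
        Left(s"column ${col.name} is ${actual.getSimpleName}, not ${col.claimed.getSimpleName}")
      case Some(_) => Right(col)
    }

  // All three lines compile; only the first one is actually correct.
  val ok = resolve(UntypedCol("i").as[Long])
  val wrongType = resolve(UntypedCol("i").as[String])
  val noSuchCol = resolve(UntypedCol("x").as[Long])
}
```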
Now, let's take a quick look at the optimized Physical Plan that Spark's Catalyst generated.
filteredDs.explain()
// == Physical Plan ==
// *(1) Filter (isnotnull(i#2419L) AND (i#2419L = 10))
// +- *(1) ColumnarToRow
// +- FileScan parquet [i#2419L] Batched: true, DataFilters: [isnotnull(i#2419L), (i#2419L = 10)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/tmp/foo], PartitionFilters: [], PushedFilters: [IsNotNull(i), EqualTo(i,10)], ReadSchema: struct<i:bigint>
//
//
The last line is the important one (see ReadSchema): Spark reads only column i from the parquet file, without accessing column j.
This is great! We have both an optimized query plan and type-safety!
Unfortunately, this syntax is not bulletproof: it fails at run time if we try to access a non-existent column x:
ds.filter($"i" === 10).select($"x".as[Long])
// org.apache.spark.sql.AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `x` cannot be resolved. Did you mean one of the following? [`i`, `j`].;
// 'Project ['x]
// +- Filter (i#2419L = cast(10 as bigint))
// +- Relation [i#2419L,j#2420] parquet
//
// at org.apache.spark.sql.errors.QueryCompilationErrors$.unresolvedAttributeError(QueryCompilationErrors.scala:306)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.org$apache$spark$sql$catalyst$analysis$CheckAnalysis$$failUnresolvedAttribute(CheckAnalysis.scala:141)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$6(CheckAnalysis.scala:299)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$6$adapted(CheckAnalysis.scala:297)
// at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:244)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$5(CheckAnalysis.scala:297)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$5$adapted(CheckAnalysis.scala:297)
// at scala.collection.immutable.Stream.foreach(Stream.scala:533)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$2(CheckAnalysis.scala:297)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$2$adapted(CheckAnalysis.scala:215)
// at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:244)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis0(CheckAnalysis.scala:215)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis0$(CheckAnalysis.scala:197)
// at org.apache.spark.sql.catalyst.analysis.Analyzer.checkAnalysis0(Analyzer.scala:202)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis(CheckAnalysis.scala:193)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis$(CheckAnalysis.scala:171)
// at org.apache.spark.sql.catalyst.analysis.Analyzer.checkAnalysis(Analyzer.scala:202)
// at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$executeAndCheck$1(Analyzer.scala:225)
// at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.markInAnalyzer(AnalysisHelper.scala:330)
// at org.apache.spark.sql.catalyst.analysis.Analyzer.executeAndCheck(Analyzer.scala:222)
// at org.apache.spark.sql.execution.QueryExecution.$anonfun$analyzed$1(QueryExecution.scala:77)
// at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:138)
// at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$2(QueryExecution.scala:219)
// at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:546)
// at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:219)
// at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
// at org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:218)
// at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:77)
// at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:74)
// at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:66)
// at org.apache.spark.sql.Dataset.<init>(Dataset.scala:206)
// at org.apache.spark.sql.Dataset.<init>(Dataset.scala:212)
// at org.apache.spark.sql.Dataset.select(Dataset.scala:1597)
// at repl.MdocSession$MdocApp0$$anonfun$1.apply(TypedDatasetVsSparkDataset.md:78)
// at repl.MdocSession$MdocApp0$$anonfun$1.apply(TypedDatasetVsSparkDataset.md:78)
There are two things to improve here. First, we want to avoid the as[Long] cast that we currently have to write for type safety. It is an easy place to introduce a bug by casting to an incompatible type. Second, we want a reference to a non-existent column to fail at compile time.
The standard Spark Dataset API can achieve both with the following syntax.
ds.filter(_.i == 10).map(_.i).show()
// +-----+
// |value|
// +-----+
// | 10|
// +-----+
//
This looks great! It resembles the familiar syntax of Scala collections. The two closures passed to filter and map are functions that operate on Foo, so the compiler helps us catch both of the mistakes mentioned above.
ds.filter(_.i == 10).map(_.x).show()
// error: value x is not a member of repl.MdocSession.MdocApp0.Foo
// ds.filter(_.i == 10).map(_.x).show()
// ^
Unfortunately, this syntax does not allow Spark to optimize the code.
ds.filter(_.i == 10).map(_.i).explain()
// == Physical Plan ==
// *(1) SerializeFromObject [input[0, bigint, false] AS value#2471L]
// +- *(1) MapElements <function1>, obj#2470: bigint
// +- *(1) Filter <function1>.apply
// +- *(1) DeserializeToObject newInstance(class repl.MdocSession$MdocApp0$Foo), obj#2469: repl.MdocSession$MdocApp0$Foo
// +- *(1) ColumnarToRow
// +- FileScan parquet [i#2419L,j#2420] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/tmp/foo], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<i:bigint,j:string>
//
//
As the Physical Plan shows, Spark was not able to optimize our query as before. Reading the parquet file requires loading all the fields of Foo (ReadSchema: struct<i:bigint,j:string>), and no filter is pushed down to the scan (PushedFilters: []). This might be acceptable for small datasets or for datasets with few columns, but it becomes costly for datasets with many columns, as is common in practical applications. Intuitively, Spark has no way to look inside the code we pass in these two closures. It only knows that both take one argument of type Foo; it cannot tell whether we use just one or all of Foo's fields.
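The difference can be sketched in plain Scala. The sketch assumes a hypothetical mini query representation; the names Expr, Step and requiredColumns are illustrative and are not Catalyst's API. An expression tree, such as a column reference compared to a literal, is data that an optimizer can walk to collect the columns it touches. A closure like Foo => Boolean is compiled code, so the only safe assumption is that it may read every field. Catalyst is far richer than this; the sketch only captures that asymmetry.

```scala
object ClosuresVsExpressions {
  case class Foo(i: Long, j: String)
  val fooColumns: Set[String] = Set("i", "j")

  // Inspectable expressions: an optimizer can see which columns they use.
  sealed trait Expr
  final case class Col(name: String) extends Expr
  final case class Lit(value: Any) extends Expr
  final case class EqualTo(left: Expr, right: Expr) extends Expr

  def references(e: Expr): Set[String] = e match {
    case Col(name)     => Set(name)
    case Lit(_)        => Set.empty
    case EqualTo(l, r) => references(l) ++ references(r)
  }

  // A plan step is either an inspectable expression or an opaque Scala function.
  sealed trait Step
  final case class FilterExpr(cond: Expr) extends Step
  final case class SelectExpr(cols: Seq[Expr]) extends Step
  final case class FilterFn(f: Foo => Boolean) extends Step
  final case class MapFn(f: Foo => Long) extends Step

  // Columns a scan must read to evaluate the given steps.
  def requiredColumns(steps: Seq[Step]): Set[String] =
    steps.flatMap {
      case FilterExpr(c)          => references(c)
      case SelectExpr(cs)         => cs.flatMap(references).toSet
      // A closure cannot be inspected, so every column must be loaded.
      case FilterFn(_) | MapFn(_) => fooColumns
    }.toSet

  val exprPlan: Seq[Step] = Seq(FilterExpr(EqualTo(Col("i"), Lit(10L))), SelectExpr(Seq(Col("i"))))
  val closurePlan: Seq[Step] = Seq(FilterFn(_.i == 10), MapFn(_.i))
}
```

For the expression plan only i is required, while the closure plan forces reading both i and j. This mirrors the two ReadSchema values seen in the plans above.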
The TypedDataset in Frameless solves this problem. It allows for a simple and type-safe syntax with a fully optimized query plan.
import frameless.TypedDataset
import frameless.syntax._
val fds = TypedDataset.create(ds)
// fds: TypedDataset[Foo] = [i: bigint, j: string]
fds.filter(fds('i) === 10).select(fds('i)).show().run()
// +-----+
// |value|
// +-----+
// | 10|
// +-----+
//
And the optimized Physical Plan:
fds.filter(fds('i) === 10).select(fds('i)).explain()
// == Physical Plan ==
// *(1) Project [i#2419L AS value#2555L]
// +- *(1) Filter (isnotnull(i#2419L) AND (i#2419L = 10))
// +- *(1) ColumnarToRow
// +- FileScan parquet [i#2419L] Batched: true, DataFilters: [isnotnull(i#2419L), (i#2419L = 10)], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/tmp/foo], PartitionFilters: [], PushedFilters: [IsNotNull(i), EqualTo(i,10)], ReadSchema: struct<i:bigint>
//
//
And the compiler is our friend.
fds.filter(fds('i) === 10).select(fds('x))
// error: No column Symbol with shapeless.tag.Tagged[String("x")] of type A in repl.MdocSession.MdocApp0.Foo
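One way to get compile-time column checks while keeping columns inspectable is to require type-level evidence that a record has a field with a given name and type. Below is a hand-written sketch of that idea for Scala 2.13 or 3, which relies on literal types. The names ColumnOf, Col and columnsOf are hypothetical. Frameless builds comparable evidence automatically with shapeless (the tagged Symbol in the error above hints at this) rather than from hand-written instances. Treat this only as an illustration of the mechanism.

```scala
// Hand-written sketch (hypothetical names): evidence that record T has a
// column named K (a literal String type) holding values of type V.
object TypedColumns {
  trait ColumnOf[T, K <: String, V]

  case class Foo(i: Long, j: String)
  object Foo {
    implicit val iCol: ColumnOf[Foo, "i", Long] = new ColumnOf[Foo, "i", Long] {}
    implicit val jCol: ColumnOf[Foo, "j", String] = new ColumnOf[Foo, "j", String] {}
  }

  // A column that keeps its name (so a planner can inspect it) and its type V.
  final case class Col[T, V](name: String)

  // Fixes T first, then infers the literal name K and looks up the column type V.
  final class ColumnsOf[T] {
    def apply[K <: String with Singleton, V](name: K)(implicit ev: ColumnOf[T, K, V]): Col[T, V] =
      Col[T, V](name)
  }
  def columnsOf[T]: ColumnsOf[T] = new ColumnsOf[T]

  // V is found by implicit search: no as[Long] cast is written by hand.
  val i: Col[Foo, Long] = columnsOf[Foo]("i")
  val j: Col[Foo, String] = columnsOf[Foo]("j")
  // columnsOf[Foo]("x") would not compile: there is no ColumnOf[Foo, "x", V].
}
```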
Differences in Encoders
Encoders in Spark's Datasets are only partially type-safe. If you try to create a Dataset using a type that is neither a supported primitive nor a Scala Product (such as a case class), you get a compilation error:
class Bar(i: Int)
Bar is neither a case class nor a Product, so the following correctly gives a compilation error in Spark:
spark.createDataset(Seq(new Bar(1)))
// error: Unable to find encoder for type repl.MdocSession.MdocApp0.Bar. An implicit Encoder[repl.MdocSession.MdocApp0.Bar] is needed to store repl.MdocSession.MdocApp0.Bar instances in a Dataset. Primitive types (Int, String, etc) and Product types (case classes) are supported by importing spark.implicits._ Support for serializing other types will be added in future releases.
However, Spark's compile-time guards are not sufficient to detect members that cannot be encoded. For example, the following case class compiles, but using it leads to a runtime failure:
case class MyDate(jday: java.util.Calendar)
spark.createDataset(Seq(MyDate {
val cal = new java.util.GregorianCalendar()
cal.setTime(new java.util.Date(System.currentTimeMillis))
cal
}))
// org.apache.spark.SparkUnsupportedOperationException: [ENCODER_NOT_FOUND] Not found an encoder of the type java.util.Calendar to Spark SQL internal representation. Consider to change the input type to one of supported at 'https://spark.apache.org/docs/latest/sql-ref-datatypes.html'.
// at org.apache.spark.sql.errors.ExecutionErrors.cannotFindEncoderForTypeError(ExecutionErrors.scala:172)
// at org.apache.spark.sql.errors.ExecutionErrors.cannotFindEncoderForTypeError$(ExecutionErrors.scala:167)
// at org.apache.spark.sql.errors.ExecutionErrors$.cannotFindEncoderForTypeError(ExecutionErrors.scala:218)
// at org.apache.spark.sql.catalyst.ScalaReflection$.encoderFor(ScalaReflection.scala:400)
// at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$encoderFor$3(ScalaReflection.scala:394)
// at scala.collection.immutable.List.map(List.scala:293)
// at org.apache.spark.sql.catalyst.ScalaReflection$.encoderFor(ScalaReflection.scala:382)
// at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$encoderFor$1(ScalaReflection.scala:247)
// at scala.reflect.internal.tpe.TypeConstraints$UndoLog.undo(TypeConstraints.scala:73)
// at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects(ScalaReflection.scala:426)
// at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects$(ScalaReflection.scala:425)
// at org.apache.spark.sql.catalyst.ScalaReflection$.cleanUpReflectionObjects(ScalaReflection.scala:42)
// at org.apache.spark.sql.catalyst.ScalaReflection$.encoderFor(ScalaReflection.scala:244)
// at org.apache.spark.sql.catalyst.ScalaReflection$.encoderFor(ScalaReflection.scala:227)
// at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$.apply(ExpressionEncoder.scala:51)
// at org.apache.spark.sql.Encoders$.product(Encoders.scala:315)
// at org.apache.spark.sql.LowPrioritySQLImplicits.newProductEncoder(SQLImplicits.scala:264)
// at org.apache.spark.sql.LowPrioritySQLImplicits.newProductEncoder$(SQLImplicits.scala:264)
// at org.apache.spark.sql.SQLImplicits.newProductEncoder(SQLImplicits.scala:32)
// at repl.MdocSession$MdocApp0$$anonfun$18.apply(TypedDatasetVsSparkDataset.md:151)
// at repl.MdocSession$MdocApp0$$anonfun$18.apply(TypedDatasetVsSparkDataset.md:151)
In comparison, TypedDataset reports the encoding problem at compile time:
TypedDataset.create(Seq(MyDate {
val cal = new java.util.GregorianCalendar()
cal.setTime(new java.util.Date(System.currentTimeMillis))
cal
}))
// error: could not find implicit value for parameter encoder: frameless.TypedEncoder[repl.MdocSession.MdocApp0.MyDate]
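This compile-time behaviour follows from treating encoders as a type class. Instances exist only for supported types, and an instance for a record is built from the instances of its fields. Here is a minimal plain-Scala sketch of that idea. Enc, sqlType and the instances are hypothetical stand-ins for frameless.TypedEncoder, whose real derivation is automatic rather than hand-written.

```scala
object EncoderSketch {
  // Hypothetical encoder type class: an instance exists only for storable types.
  trait Enc[A] { def sqlType: String }

  object Enc {
    implicit val longEnc: Enc[Long] = new Enc[Long] { val sqlType: String = "bigint" }
    implicit val stringEnc: Enc[String] = new Enc[String] { val sqlType: String = "string" }
    // Deliberately no instance for java.util.Calendar.
  }

  case class Foo(i: Long, j: String)
  object Foo {
    // A record is encodable only if every field is encodable.
    implicit def fooEnc(implicit ei: Enc[Long], ej: Enc[String]): Enc[Foo] =
      new Enc[Foo] { val sqlType: String = s"struct<i:${ei.sqlType},j:${ej.sqlType}>" }
  }

  case class MyDate(jday: java.util.Calendar)
  object MyDate {
    implicit def myDateEnc(implicit ec: Enc[java.util.Calendar]): Enc[MyDate] =
      new Enc[MyDate] { val sqlType: String = s"struct<jday:${ec.sqlType}>" }
  }

  val fooSchema: String = implicitly[Enc[Foo]].sqlType
  // implicitly[Enc[MyDate]] does not compile: no Enc[java.util.Calendar] exists,
  // so the problem is reported before the program runs.
}
```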
Aggregate vs Projected columns
Spark's Datasets do not distinguish between columns created by aggregate operations, such as summing or averaging, and simple projections/selections. This is problematic when you start mixing the two, since the mistake is only caught at run time:
import org.apache.spark.sql.functions.sum
ds.select(sum($"i"), $"i"*2)
// org.apache.spark.sql.AnalysisException: [MISSING_GROUP_BY] The query does not include a GROUP BY clause. Add GROUP BY or turn it into the window functions using OVER clauses.;
// Aggregate [sum(i#2419L) AS sum(i)#2561L, (i#2419L * cast(2 as bigint)) AS (i * 2)#2562L]
// +- Relation [i#2419L,j#2420] parquet
//
// at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:52)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkValidAggregateExpression$1(CheckAnalysis.scala:474)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$19(CheckAnalysis.scala:490)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$19$adapted(CheckAnalysis.scala:490)
// at scala.collection.Iterator.foreach(Iterator.scala:943)
// at scala.collection.Iterator.foreach$(Iterator.scala:943)
// at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
// at scala.collection.IterableLike.foreach(IterableLike.scala:74)
// at scala.collection.IterableLike.foreach$(IterableLike.scala:73)
// at scala.collection.AbstractIterable.foreach(Iterable.scala:56)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkValidAggregateExpression$1(CheckAnalysis.scala:490)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$19(CheckAnalysis.scala:490)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$19$adapted(CheckAnalysis.scala:490)
// at scala.collection.Iterator.foreach(Iterator.scala:943)
// at scala.collection.Iterator.foreach$(Iterator.scala:943)
// at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
// at scala.collection.IterableLike.foreach(IterableLike.scala:74)
// at scala.collection.IterableLike.foreach$(IterableLike.scala:73)
// at scala.collection.AbstractIterable.foreach(Iterable.scala:56)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkValidAggregateExpression$1(CheckAnalysis.scala:490)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$22(CheckAnalysis.scala:522)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$22$adapted(CheckAnalysis.scala:522)
// at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
// at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
// at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$2(CheckAnalysis.scala:522)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$2$adapted(CheckAnalysis.scala:215)
// at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:244)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis0(CheckAnalysis.scala:215)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis0$(CheckAnalysis.scala:197)
// at org.apache.spark.sql.catalyst.analysis.Analyzer.checkAnalysis0(Analyzer.scala:202)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis(CheckAnalysis.scala:193)
// at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.checkAnalysis$(CheckAnalysis.scala:171)
// at org.apache.spark.sql.catalyst.analysis.Analyzer.checkAnalysis(Analyzer.scala:202)
// at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$executeAndCheck$1(Analyzer.scala:225)
// at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.markInAnalyzer(AnalysisHelper.scala:330)
// at org.apache.spark.sql.catalyst.analysis.Analyzer.executeAndCheck(Analyzer.scala:222)
// at org.apache.spark.sql.execution.QueryExecution.$anonfun$analyzed$1(QueryExecution.scala:77)
// at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:138)
// at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$2(QueryExecution.scala:219)
// at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:546)
// at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:219)
// at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
// at org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:218)
// at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:77)
// at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:74)
// at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:66)
// at org.apache.spark.sql.Dataset$.$anonfun$ofRows$1(Dataset.scala:91)
// at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
// at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:89)
// at org.apache.spark.sql.Dataset.withPlan(Dataset.scala:4352)
// at org.apache.spark.sql.Dataset.select(Dataset.scala:1542)
// at repl.MdocSession$MdocApp0$$anonfun$19.apply(TypedDatasetVsSparkDataset.md:177)
// at repl.MdocSession$MdocApp0$$anonfun$19.apply(TypedDatasetVsSparkDataset.md:177)
In Frameless, aggregated and projected columns have different types, so using an aggregation where a projection is expected results in a compilation error.
// Rename frameless' sum to avoid confusing it with Spark's standard sum
import frameless.functions.aggregate.{sum => fsum}
fds.select(fsum(fds('i)))
// error: polymorphic expression cannot be instantiated to expected type;
// found : [Out]frameless.TypedAggregate[repl.MdocSession.MdocApp0.Foo,Out]
// required: frameless.TypedColumn[repl.MdocSession.MdocApp0.Foo,?]
As the error shows, select expected a TypedColumn but got a TypedAggregate instead.
Here is how you apply an aggregation method in Frameless:
fds.agg(fsum(fds('i))+22).show().run()
// +-----+
// |value|
// +-----+
// | 133|
// +-----+
//
Similarly, mixing projections into an aggregation does not make sense, and in Frameless it results in a compilation error.
fds.agg(fsum(fds('i)), fds('i)).show().run()
// error: polymorphic expression cannot be instantiated to expected type;
// found : [A]frameless.TypedColumn[repl.MdocSession.MdocApp0.Foo,A]
// required: frameless.TypedAggregate[repl.MdocSession.MdocApp0.Foo,?]
// fds.agg(fsum(fds('i)), fds('i)).show().run()
// ^^^^^^^
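The distinction works because projections and aggregations are different types, and each operation accepts only one of them. Below is a small plain-Scala sketch of that design over an in-memory Seq. Projection, Aggregate, select, agg and sumOf are hypothetical names, not Frameless' API; as the errors above show, Frameless itself uses TypedColumn and TypedAggregate.

```scala
object ProjectionVsAggregate {
  case class Foo(i: Long, j: String)

  // A per-row expression, like a projected column.
  final case class Projection[T, A](f: T => A)

  // A whole-dataset expression, like an aggregated column.
  final case class Aggregate[T, A](f: Seq[T] => A) {
    // Arithmetic on an aggregate is still an aggregate.
    def +(n: Long)(implicit ev: A =:= Long): Aggregate[T, Long] =
      Aggregate[T, Long](rows => ev(f(rows)) + n)
  }

  def sumOf[T](p: Projection[T, Long]): Aggregate[T, Long] =
    Aggregate[T, Long](rows => rows.map(p.f).sum)

  // select accepts only projections; agg accepts only aggregates.
  def select[T, A](data: Seq[T])(p: Projection[T, A]): Seq[A] = data.map(p.f)
  def agg[T, A](data: Seq[T])(a: Aggregate[T, A]): A = a.f(data)

  val data: Seq[Foo] = Seq(Foo(1, "Q"), Foo(10, "W"), Foo(100, "E"))
  val i: Projection[Foo, Long] = Projection[Foo, Long](_.i)

  val doubled: Seq[Long] = select(data)(Projection[Foo, Long](_.i * 2))
  val total: Long = agg(data)(sumOf(i) + 22)
  // select(data)(sumOf(i)) and agg(data)(i) would not compile: the types differ.
}
```

With the three rows used throughout this tutorial, total is 1 + 10 + 100 + 22 = 133, which matches the fds.agg(fsum(fds('i))+22) output above.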