How to convert spark SchemaRDD into RDD of my case class? - sql

The Spark docs make it clear how to create Parquet files from an RDD of your own case classes (from the docs):
val people: RDD[Person] = ??? // An RDD of case class objects, from the previous example.
// The RDD is implicitly converted to a SchemaRDD by createSchemaRDD, allowing it to be stored using Parquet.
But it is not clear how to convert back. What we really want is a method readParquetFile that lets us do:
val people: RDD[Person] = sc.readParquetFile[Person](path)
where the fields defined on the case class are the ones the method reads.

An easy way is to provide your own converter (Row) => CaseClass. This is a bit more manual, but if you know what you are reading it should be quite straightforward.
Here is an example:
import org.apache.spark.sql.{Row, SchemaRDD}
case class User(data: String, name: String, id: Long)
// Returns None for rows that do not have the expected shape, so flatMap drops them
def sparkSqlToUser(r: Row): Option[User] = {
  r match {
    case Row(time: String, name: String, id: Long) => Some(User(time, name, id))
    case _ => None
  }
}
val parquetData: SchemaRDD = sqlContext.parquetFile("hdfs://localhost/user/data.parquet")
val caseClassRdd: org.apache.spark.rdd.RDD[User] = parquetData.flatMap(sparkSqlToUser)
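To make the idea of a (Row) => Option[CaseClass] converter concrete without needing a cluster, here is a minimal sketch in plain Python (no Spark involved). The `User` dataclass mirrors the case class above; `row_to_user` and the sample rows are made up for illustration. Rows that do not have the expected shape become None and are filtered out, which is what flatMap over an Option does in the Scala version.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass(frozen=True)
class User:
    data: str
    name: str
    id: int


def row_to_user(row: Sequence[object]) -> Optional[User]:
    """Return a User when the row has the expected shape, otherwise None."""
    if len(row) != 3:
        return None
    data, name, user_id = row
    # bool is a subclass of int in Python, so it is excluded explicitly
    if (isinstance(data, str) and isinstance(name, str)
            and isinstance(user_id, int) and not isinstance(user_id, bool)):
        return User(data, name, user_id)
    return None


# Hypothetical rows: one well-formed, two malformed
rows = [("2014-01-01", "alice", 1), ("bad row",), ("2014-01-02", "bob", "not-a-long")]

# Equivalent of flatMap over Option: keep only the rows that converted
users = [u for u in map(row_to_user, rows) if u is not None]
print(users)
```

The trade-off is the same as in the Scala answer: malformed rows are silently dropped, so if you need to know about them you would count or log the None results instead.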

The best solution I've come up with that requires the least copying and pasting for new classes is as follows (I'd still like to see another solution, though).
First you have to define your case class and a (partially) reusable factory method:
import org.apache.spark.sql.catalyst.expressions
case class MyClass(fooBar: Long, fred: Long)
// Here you want to auto gen these functions using macros or something
object Factories {
  def longLong[T](fac: (Long, Long) => T)(row: expressions.Row): T =
    fac(row(0).asInstanceOf[Long], row(1).asInstanceOf[Long])
}
Some boiler plate which will already be available
import scala.reflect.runtime.universe._
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext.createSchemaRDD
The magic
import scala.reflect.ClassTag
import org.apache.spark.sql.SchemaRDD
// fooBar -> foo_bar
def camelToUnderscores(name: String) =
  "[A-Z]".r.replaceAllIn(name, "_" + _.group(0).toLowerCase())
// Names of the case class accessors, in declaration order
def getCaseMethods[T: TypeTag]: List[String] = typeOf[T].members.sorted.collect {
  case m: MethodSymbol if m.isCaseAccessor => m
}.toList.map(_.toString)
def caseClassToSQLCols[T: TypeTag]: List[String] =
  getCaseMethods[T].map(_.split(" ")(1)).map(camelToUnderscores)
def schemaRDDToRDD[T: TypeTag: ClassTag](schemaRDD: SchemaRDD, fac: expressions.Row => T) = {
  val tmpName = "tmpTableName" // Maybe should use a random string
  // The table has to be registered before it can be queried (registerAsTable on Spark 1.0)
  schemaRDD.registerTempTable(tmpName)
  sqlContext.sql("SELECT " + caseClassToSQLCols[T].mkString(", ") + " FROM " + tmpName)
    .map(fac)
}
Example use
val parquetFile = sqlContext.parquetFile(path)
val normalRDD: RDD[MyClass] =
schemaRDDToRDD[MyClass](parquetFile, Factories.longLong[MyClass](MyClass.apply))
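The column-name part of this trick (camelCase field names turned into underscore column names, then into a SELECT list) can be tried in isolation. Below is a rough sketch in plain Python, assuming nothing about Spark: `MyClass` mirrors the case class above, while `class_to_sql_cols` and `select_statement` are hypothetical helper names.

```python
import re
from dataclasses import dataclass, fields


def camel_to_underscores(name: str) -> str:
    """fooBar -> foo_bar (every capital letter becomes '_' plus its lowercase)."""
    return re.sub(r"[A-Z]", lambda m: "_" + m.group(0).lower(), name)


@dataclass
class MyClass:
    fooBar: int
    fred: int


def class_to_sql_cols(cls) -> list:
    # dataclasses.fields returns the fields in declaration order
    return [camel_to_underscores(f.name) for f in fields(cls)]


def select_statement(cls, table: str) -> str:
    return "SELECT " + ", ".join(class_to_sql_cols(cls)) + " FROM " + table


print(select_statement(MyClass, "tmpTableName"))
# SELECT foo_bar, fred FROM tmpTableName
```

Note that a field name starting with a capital letter would get a leading underscore, exactly as the regex in the Scala version would produce.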
There is a simple way to convert a SchemaRDD to an RDD using PySpark in Spark 1.2.1:
sc = SparkContext() ## create SparkContext
srdd = sqlContext.sql(sql)
c = srdd.collect() ## convert rdd to list
rdd = sc.parallelize(c)
There must be a similar approach in Scala.

A very crufty attempt. I am very unconvinced this will have decent performance. Surely there must be a macro-based alternative...
import scala.reflect.runtime.universe.typeOf
import scala.reflect.runtime.universe.MethodSymbol
import scala.reflect.runtime.universe.NullaryMethodType
import scala.reflect.runtime.universe.TypeRef
import scala.reflect.runtime.universe.Type
import scala.reflect.runtime.universe.NoType
import scala.reflect.runtime.universe.termNames
import scala.reflect.runtime.universe.runtimeMirror
// Usage, per row of the SchemaRDD: RowToCaseClass.rowToCaseClass(row.toSeq, typeOf[X], 0)
object RowToCaseClass {
  def rowToCaseClass(record: Seq[_], t: Type, depth: Int): Any = {
    // Case class fields, in declaration order
    val fields = t.decls.sorted.collect {
      case m: MethodSymbol if m.isCaseAccessor => m
    }
    val values = fields.zipWithIndex.map {
      case (field, i) =>
        field.typeSignature match {
          case NullaryMethodType(sig) if sig =:= typeOf[String] => record(i).asInstanceOf[String]
          case NullaryMethodType(sig) if sig =:= typeOf[Int] => record(i).asInstanceOf[Int]
          case NullaryMethodType(sig) =>
            if (sig.baseType(typeOf[Seq[_]].typeSymbol) != NoType) {
              // A Seq field: convert each nested record to the element type
              sig match {
                case TypeRef(_, _, args) =>
                  record(i).asInstanceOf[Seq[Seq[_]]].map {
                    r => rowToCaseClass(r, args(0), depth + 1)
                  }
              }
            } else {
              // Any other field is treated as a nested case class
              sig match {
                case TypeRef(_, u, _) =>
                  rowToCaseClass(record(i).asInstanceOf[Seq[_]], sig, depth + 1)
              }
            }
        }
    }
    val mirror = runtimeMirror(t.getClass.getClassLoader)
    val ctor = t.member(termNames.CONSTRUCTOR).asMethod
    val klass = t.typeSymbol.asClass
    val method = mirror.reflectClass(klass).reflectConstructor(ctor)
    method.apply(values: _*)
  }
}
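For anyone who wants to see the recursion without the Scala reflection noise, here is a small sketch of the same idea in plain Python, using dataclasses as stand-ins for case classes. `Address`, `Person` and `row_to_dataclass` are hypothetical names, and nested records are assumed to arrive as plain tuples or lists. It only handles nested dataclasses and lists of dataclasses, so treat it as an illustration of the approach rather than a complete converter.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import List, get_args, get_origin, get_type_hints


@dataclass
class Address:
    city: str
    zip_code: int


@dataclass
class Person:
    name: str
    age: int
    home: Address
    previous: List[Address]


def row_to_dataclass(record, cls):
    """Build cls from a positional record, recursing into nested dataclasses."""
    hints = get_type_hints(cls)
    values = []
    for field, raw in zip(fields(cls), record):
        tpe = hints[field.name]
        if is_dataclass(tpe):
            # Nested record -> nested dataclass
            values.append(row_to_dataclass(raw, tpe))
        elif get_origin(tpe) is list and is_dataclass(get_args(tpe)[0]):
            # List of nested records -> list of dataclasses
            element_type = get_args(tpe)[0]
            values.append([row_to_dataclass(r, element_type) for r in raw])
        else:
            # Primitive field: taken as-is, no validation
            values.append(raw)
    return cls(*values)


person = row_to_dataclass(("Ann", 30, ("Oslo", 150), [("Rome", 100)]), Person)
print(person)
```

Like the Scala version, this pays the reflection cost on every row, which is why the performance concern above seems justified.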


Creating an object builder with error handling using Arrow - Pattern match multiple Eithers

I have class A:
class A (private var z: String, private var y: String, private var x: Int)
I want to create a failsafe builder for it. The builder should return Either the list of Exceptions (e.g. when values are missing) or the created values. What is the recommended way to create something like this? Or is there a conceptually better approach?
My own approach to it:
sealed class ABuilderException {
    object MissingXValue : ABuilderException()
    object MissingYValue : ABuilderException()
    object MissingZValue : ABuilderException()
}
import arrow.core.Either
import arrow.core.Option
import arrow.core.none
import arrow.core.some
class ABuilder {
    private var x : Option<Int> = none()
    private var y : Option<String> = none()
    private var z : Option<String> = none()
    fun withX(x : Int) : ABuilder {
        this.x = x.some();
        return this;
    }
    fun withY(y : String) : ABuilder {
        this.y = y.some();
        return this;
    }
    fun withZ(z : String) : ABuilder {
        this.z = z.some();
        return this;
    }
    fun build() : Either<A, List<ABuilderException>> {
        var xEither = x.toEither { ABuilderException.MissingXValue }
        var yEither = y.toEither { ABuilderException.MissingYValue }
        var zEither = z.toEither { ABuilderException.MissingZValue }
        // If all values are not an exception, create A
        // otherwise: Return the list of exceptions
    }
}
How could I best complete the build code?
I favor a solution that avoids deep nesting (e.g. orElse or similar methods) and avoids repeating values (e.g. by recreating Tuples), because this may lead to typos and makes it harder to add/remove properties later.
First you need to change the signature of build to:
fun build() : Either<List<ABuilderException>, A>
The reason is that Either is right-biased: functions like map, flatMap, etc. operate on the Right value and are no-ops when the value is Left.
For combining Either values you can use zip:
val e1 = 2.right()
val e2 = 3.right()
// By default it gives you a `Pair` of the two
val c1 = e1.zip(e2) // Either.Right((2, 3))
// Or you can pass a custom combine function
val c2 = e1.zip(e2) { two, three -> two + three } // Either.Right(5)
However, there is an issue here: in case of an error (one of them is Left), it fails fast and gives you only the first one.
To accumulate the errors we can use Validated:
val x = none<Int>()
val y = none<String>()
val z = none<String>()
// Validated<String, Int>
val xa = Validated.fromOption(x) { "X is missing" }
// Validated<String, String>
val ya = Validated.fromOption(y) { "Y is missing" }
// Validated<String, String>
val za = Validated.fromOption(z) { "Z is missing" }
xa.toValidatedNel().zip(
    ya.toValidatedNel(),
    za.toValidatedNel()
) { x, y, z -> TODO() }
Validated, like Either, has a zip function for combining values. The difference is that Validated will accumulate the errors. In the lambda you have access to the valid values (Int, String, String) and you can create your valid object.
toValidatedNel() here converts from Validated<String, String> to Validated<Nel<String>, String>, where Nel is a list that can NOT be empty. Accumulating errors as a List is common, so it's built in.
For more, you can check the Error Handling tutorial in the docs.
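The difference between failing fast and accumulating is easy to see in a language-neutral sketch. The following plain Python (not Arrow, and not Kotlin) is only an illustration: `build_fail_fast` and `build_accumulating` are hypothetical names, missing values are modelled as None, and the error strings are the ones used above.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union


@dataclass(frozen=True)
class A:
    z: str
    y: str
    x: int


def _checks(x, y, z):
    return ((x, "X is missing"), (y, "Y is missing"), (z, "Z is missing"))


def build_fail_fast(x, y, z) -> Union[str, A]:
    """Either-style: stop at the first missing value and report only that one."""
    for value, error in _checks(x, y, z):
        if value is None:
            return error
    return A(z, y, x)


def build_accumulating(x, y, z) -> Tuple[List[str], Optional[A]]:
    """Validated-style: inspect every value and collect all the errors."""
    errors = [error for value, error in _checks(x, y, z) if value is None]
    if errors:
        return errors, None
    return [], A(z, y, x)


print(build_fail_fast(None, "y", None))      # only the first problem
print(build_accumulating(None, "y", None))   # every problem
print(build_accumulating(1, "y", "z"))       # no errors, the built object
```

Adding a property then means adding one entry to the list of checks, which is roughly the property the question asks for: no deep nesting and no tuples to keep in sync.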

Why is Kotlin's generateSequence returning one too many items in the example below?

I'm calculating the projection of instants in time based on a cron expression and returning them as a Sequence. Here's the class:
// (package omitted)
import java.time.Instant
import java.time.LocalDate
import java.time.LocalDateTime
import java.time.ZonedDateTime
import java.time.temporal.ChronoUnit
class Recurrence(val cronExpression: String) {
private val cron = CronExpression.parse(cronExpression)
fun instants(
fromInclusive: LocalDate =,
toExclusive: LocalDate = fromInclusive.plusMonths(1)
): Sequence<LocalDateTime> = instants(fromInclusive.atStartOfDay(), toExclusive.atStartOfDay())
fun instants(
fromInclusive: LocalDateTime =,
toExclusive: LocalDateTime = fromInclusive.plusMonths(1)
): Sequence<LocalDateTime> {
return generateSequence( {
if (it.isBefore(toExclusive)) {
} else {
The following test fails because the first assertion is false: the returned list has one extra, unexpected element at the end.
// (package omitted)
import java.time.LocalDate
import java.time.Month
import kotlin.test.Test
import kotlin.test.assertEquals
class RecurrenceTest {
fun testInstants() {
val r = Recurrence("@daily")
val from = LocalDate.of(2021, Month.JANUARY, 1)
val forDays = 31
val instants = r.instants(from, from.plusDays(forDays.toLong())).toList()
assertEquals(forDays, instants.size)
(1..forDays).forEach {
assertEquals(from.plusDays(it.toLong() - 1).atStartOfDay(), instants[it - 1])
If I reimplement by building an ArrayList instead, it works as expected:
// new collection-based methods in Recurrence
fun instantsList(
fromInclusive: LocalDate =,
toExclusive: LocalDate = fromInclusive.plusMonths(1)
): List<LocalDateTime> = instantsList(fromInclusive.atStartOfDay(), toExclusive.atStartOfDay())
fun instantsList(
fromInclusive: LocalDateTime =,
toExclusive: LocalDateTime = fromInclusive.plusMonths(1)
): List<LocalDateTime> {
val list = arrayListOf<LocalDateTime>()
var it =
while (it !== null) {
if (it.isBefore(toExclusive)) {
it =
} else {
return list
The one line to change in the test is to use the new method:
val instants = r.instantsList(from, from.plusDays(forDays.toLong()))
Why is the sequence-based implementation returning me one more element than the list-based one?
If I read your code correctly, in the list implementation you check it.isBefore(toExclusive) and only then add it to the list. In the sequence implementation you do the same check on it, but then add the next item to the sequence without checking it.
The same goes for the first item: in the list implementation you check whether it meets the requirement, whereas in the sequence implementation you always add it.
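The off-by-one can be reproduced with plain integers. This is a sketch in Python rather than Kotlin: `generate_sequence` is a hand-written stand-in that mimics the documented behaviour of generateSequence(seed, nextFunction) (emit the seed, then keep calling the function until it returns null), and `LIMIT` plays the role of toExclusive.

```python
from typing import Callable, Iterator, List, Optional


def generate_sequence(seed: Optional[int],
                      next_fn: Callable[[int], Optional[int]]) -> Iterator[int]:
    """Yield seed, then next_fn(previous), until a None is produced."""
    current = seed
    while current is not None:
        yield current
        current = next_fn(current)


LIMIT = 5  # exclusive upper bound, standing in for toExclusive


def buggy(start: int) -> List[int]:
    # Checks the *current* element, then emits the *next* one unchecked
    return list(generate_sequence(start, lambda it: it + 1 if it < LIMIT else None))


def fixed(start: int) -> List[int]:
    def step(it: int) -> Optional[int]:
        nxt = it + 1
        # Check the element that is about to be emitted
        return nxt if nxt < LIMIT else None

    # The seed is emitted unconditionally, so it needs its own check
    seed = start if start < LIMIT else None
    return list(generate_sequence(seed, step))


print(buggy(0))  # [0, 1, 2, 3, 4, 5]  one element too many
print(fixed(0))  # [0, 1, 2, 3, 4]
```

The same two checks (on the seed and on the next value) are what the list-based loop does implicitly, which is why it never showed the extra element.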
Thanks, @broot, you spotted the issue. It just took another pair of eyes. The correct sequence implementation is:
fun instants(
fromInclusive: LocalDateTime =,
toExclusive: LocalDateTime = fromInclusive.plusMonths(1)
): Sequence<LocalDateTime> {
val seed =
return generateSequence(seed) {
val next =
if (next.isBefore(toExclusive)) {
} else {

Converting SQL Query with Aggregate Function to Relational Algebra Expression in Apache Calcite - No match found for function signature

I'm trying to convert a SQL query to a relational algebra expression using the Apache Calcite SqlToRelConverter.
It works fine for this query (quotes are for ensuring lowercase):
queryToRelationalAlgebraRoot("SELECT \"country\" FROM \"mytable\"")
But on this query it fails:
queryToRelationalAlgebraRoot("SELECT \"country\", SUM(\"salary\") FROM \"mytable\" GROUP BY \"country\"")
with this error:
org.apache.calcite.sql.validate.SqlValidatorException: No match found for function signature SUM(<NUMERIC>)
It seems that somehow the SQL validator doesn't have aggregation functions like sum or count registered.
case class Income(id: Int, salary: Double, country: String)
class SparkDataFrameTable(df: DataFrame) extends AbstractTable {
  def getRowType(typeFactory: RelDataTypeFactory): RelDataType = {
    // Map each Spark column type to the corresponding Calcite SQL type
    val typeList = df.schema.fields.map {
      field => field.dataType match {
        case t: StringType => typeFactory.createSqlType(SqlTypeName.VARCHAR)
        case t: IntegerType => typeFactory.createSqlType(SqlTypeName.INTEGER)
        case t: DoubleType => typeFactory.createSqlType(SqlTypeName.DOUBLE)
      }
    }.toList.asJava
    val fieldNameList = df.schema.fieldNames.toList.asJava
    typeFactory.createStructType(typeList, fieldNameList)
  }
}
object RelationalAlgebra {
  def queryToRelationalAlgebraRoot(query: String): RelRoot = {
    val sqlParser = SqlParser.create(query)
    val sqlParseTree = sqlParser.parseQuery()
    val frameworkConfig = Frameworks.newConfigBuilder().build()
    val planner = new PlannerImpl(frameworkConfig)
    val rootSchema = CalciteSchema.createRootSchema(true, true)
    // some sample data for testing
    val inc1 = new Income(1, 100000, "USA")
    val inc2 = new Income(2, 110000, "USA")
    val inc3 = new Income(3, 80000, "Canada")
    val spark = SparkSession.builder().master("local").getOrCreate()
    import spark.implicits._
    val df = Seq(inc1, inc2, inc3).toDF()
    rootSchema.add("mytable", new SparkDataFrameTable(df))
    val defaultSchema = List[String]().asJava
    val calciteConnectionConfigProperties = new Properties()
    val calciteConnectionConfigImpl = new CalciteConnectionConfigImpl(calciteConnectionConfigProperties)
    val sqlTypeFactoryImpl = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT)
    val calciteCatelogReader = new CalciteCatalogReader(rootSchema, defaultSchema, sqlTypeFactoryImpl, calciteConnectionConfigImpl)
    val defaultValidator = SqlValidatorUtil.newValidator(new SqlStdOperatorTable(), calciteCatelogReader, sqlTypeFactoryImpl, SqlConformanceEnum.LENIENT)
    val relExpressionOptimizationCluster = RelOptCluster.create(new VolcanoPlanner(), new RexBuilder(sqlTypeFactoryImpl))
    val sqlToRelConfig = SqlToRelConverter.configBuilder().build()
    val sqlToRelConverter = new SqlToRelConverter(planner, defaultValidator, calciteCatelogReader, relExpressionOptimizationCluster, StandardConvertletTable.INSTANCE, sqlToRelConfig)
    sqlToRelConverter.convertQuery(sqlParseTree, true, true)
  }
}
The problem with the code is that new SqlStdOperatorTable() creates an operator table that has not been initialized, so the validator has no functions such as SUM or COUNT registered. The correct way to use SqlStdOperatorTable is to call SqlStdOperatorTable.instance().
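As an analogy only (this is not Calcite's code, and every name below is made up), the pattern looks roughly like the following Python sketch: the constructor leaves the registry empty, and only the shared accessor populates it. That is why a lookup on a freshly constructed table finds nothing, while the same lookup through the accessor succeeds.

```python
class OperatorTable:
    """Toy stand-in for an operator table that is filled lazily."""

    _instance = None

    def __init__(self):
        # A freshly constructed table has no operators registered yet
        self._operators = {}

    def _init(self):
        # Registration happens here, not in the constructor
        for name in ("SUM", "COUNT", "MIN", "MAX"):
            self._operators[name] = name
        return self

    @classmethod
    def instance(cls):
        # Shared, initialized table: created and populated on first use
        if cls._instance is None:
            cls._instance = cls()._init()
        return cls._instance

    def lookup(self, name):
        return self._operators.get(name.upper())


print(OperatorTable().lookup("SUM"))           # None: nothing registered
print(OperatorTable.instance().lookup("SUM"))  # SUM
```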
I found the solution after emailing the mailing list. I would like to thank Yuzhao Chen for looking into the question I had and pointing out the problem with my code.
I am not familiar with the API, but your SQL needs GROUP BY country. And if a tool is going to take this output and use it, it will probably also require that you name the aggregate column with an alias.

Defining infix operator in Welder

I have the following simplified definition of an addition operation over a field:
import inox._
import inox.trees.{forall => _, _}
import inox.trees.dsl._
object Field {
  val element = FreshIdentifier("element")
  val zero = FreshIdentifier("zero")
  val one = FreshIdentifier("one")
  val elementADT = mkSort(element)()(Seq(zero, one))
  val zeroADT = mkConstructor(zero)()(Some(element)) {_ => Seq()}
  val oneADT = mkConstructor(one)()(Some(element)) {_ => Seq()}
  val addID = FreshIdentifier("add")
  val addFunction = mkFunDef(addID)("element") { case Seq(eT) =>
    val args: Seq[ValDef] = Seq("f1" :: eT, "f2" :: eT)
    val retType: Type = eT
    val body: Seq[Variable] => Expr = { case Seq(f1,f2) =>
      //do the addition for this field
      f1 //do something better...
    }
    (args, retType, body)
  }
  //-------Helper functions for arithmetic operations and zero element of field----------------
  implicit class ExprOperands(private val lhs: Expr) extends AnyVal{
    def +(rhs: Expr): Expr = E(addID)(T(element)())(lhs, rhs)
  }
}
I'd like this operation to be used with infix notation and the current solution that I find to do so in Scala is given here. So that's why I'm including the implicit class at the bottom.
Say now I want to use this definition of addition:
import inox._
import inox.trees.{forall => _, _}
import inox.trees.dsl._
import welder._
object Curve{
val affinePoint = FreshIdentifier("affinePoint")
val infinitePoint = FreshIdentifier("infinitePoint")
val finitePoint = FreshIdentifier("finitePoint")
val first = FreshIdentifier("first")
val second = FreshIdentifier("second")
val affinePointADT = mkSort(affinePoint)("F")(Seq(infinitePoint,finitePoint))
val infiniteADT = mkConstructor(infinitePoint)("F")(Some(affinePoint))(_ => Seq())
val finiteADT = mkConstructor(finitePoint)("F")(Some(affinePoint)){ case Seq(fT) =>
Seq(ValDef(first, fT), ValDef(second, fT))
val F = T(Field.element)()
val affine = T(affinePoint)(F)
val infinite = T(infinitePoint)(F)
val finite = T(finitePoint)(F)
val onCurveID = FreshIdentifier("onCurve")
val onCurveFunction = mkFunDef(onCurveID)() { case Seq() =>
val args: Seq[ValDef] = Seq("p" :: affine, "a" :: F, "b" :: F)
val retType: Type = BooleanType
val body: Seq[Variable] => Expr = { case Seq(p,a,b) =>
val x: Expr = p.asInstOf(finite).getField(first)
val y: Expr = p.asInstOf(finite).getField(second)
x === y+y
} else_ {
(args, retType, body)
//---------------------------Registering elements-----------------------------------
val symbols = NoSymbols
val program = InoxProgram(Context.empty, symbols)
val theory = theoryOf(program)
import theory._
val expr = (E(BigInt(1)) + E(BigInt(1))) === E(BigInt(2))
val theorem: Theorem = prove(expr)
This fails with the following error:
at Main$.main(Main.scala:4)
at Main.main(Main.scala)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(
at sun.reflect.DelegatingMethodAccessorImpl.invoke(
at java.lang.reflect.Method.invoke(
Caused by: inox.ast.TypeOps$TypeErrorException: Type error: if (p.isInstanceOf[finitePoint[element]]) {
p.asInstanceOf[finitePoint[element]].first == p.asInstanceOf[finitePoint[element]].second + p.asInstanceOf[finitePoint[element]].second
} else {
}, expected Boolean, found <untyped>
at inox.ast.TypeOps$TypeErrorException$.apply(TypeOps.scala:24)
at inox.ast.TypeOps$class.typeCheck(TypeOps.scala:264)
at inox.ast.SimpleSymbols$SimpleSymbols.typeCheck(SimpleSymbols.scala:12)
at inox.ast.Definitions$AbstractSymbols$$anonfun$ensureWellFormed$2.apply(Definitions.scala:166)
at inox.ast.Definitions$AbstractSymbols$$anonfun$ensureWellFormed$2.apply(Definitions.scala:165)
at scala.collection.TraversableLike$WithFilter$$anonfun$foreach$1.apply(TraversableLike.scala:733)
at scala.collection.immutable.Map$Map2.foreach(Map.scala:137)
at scala.collection.TraversableLike$WithFilter.foreach(TraversableLike.scala:732)
at inox.ast.Definitions$AbstractSymbols$class.ensureWellFormed(Definitions.scala:165)
at inox.ast.SimpleSymbols$SimpleSymbols.ensureWellFormed$lzycompute(SimpleSymbols.scala:12)
at inox.ast.SimpleSymbols$SimpleSymbols.ensureWellFormed(SimpleSymbols.scala:12)
at inox.solvers.unrolling.AbstractUnrollingSolver$class.assertCnstr(UnrollingSolver.scala:129)
at inox.solvers.SolverFactory$$anonfun$getFromName$1$$anon$1.inox$tip$TipDebugger$$super$assertCnstr(SolverFactory.scala:115)
at inox.tip.TipDebugger$class.assertCnstr(TipDebugger.scala:52)
at inox.solvers.SolverFactory$$anonfun$getFromName$1$$anon$1.assertCnstr(SolverFactory.scala:115)
at inox.solvers.SolverFactory$$anonfun$getFromName$1$$anon$1.assertCnstr(SolverFactory.scala:115)
at welder.Solvers$class.prove(Solvers.scala:51)
at welder.package$$anon$1.prove(package.scala:10)
at welder.Solvers$class.prove(Solvers.scala:23)
at welder.package$$anon$1.prove(package.scala:10)
at Curve$.<init>(curve.scala:61)
at Curve$.<clinit>(curve.scala)
at Main$.main(Main.scala:4)
at Main.main(Main.scala)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(
at sun.reflect.DelegatingMethodAccessorImpl.invoke(
at java.lang.reflect.Method.invoke(
In fact, what is happening is that in the expression x === y+y the + is not being applied correctly, so the expression is untyped. I recall that one cannot define nested objects or classes inside the objects of Welder proofs; I don't know if this has anything to do with it.
Anyway, do I have to give up on infix notation in my Welder code, or is there a solution?
The issue here is that the implicit class you defined is not visible when you create y+y (you would need to import Field._ for it to be visible).
I don't remember exactly how implicit resolution takes place in Scala, so maybe adding import Field._ inside the Curve object will override the + that comes from the inox DSL (that's the one being applied when you write y+y, giving you an arithmetic plus expression that expects integer arguments, hence the type error). Otherwise, you'll unfortunately have ambiguity in the implicit resolution, and I'm not sure it's possible to use the infix + operator in that case without giving up the whole dsl.
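To illustrate why picking up the DSL's + instead of the custom one ends in an untyped expression, here is a deliberately tiny toy model in Python. It is not Inox or Welder code: `Var`, `Plus`, `Call` and `type_of` are invented for the illustration, and the rule that the built-in plus only accepts Integer operands is a simplification.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Var:
    name: str
    tpe: str  # e.g. "Integer" or "element"


@dataclass(frozen=True)
class Plus:
    """Built-in arithmetic plus, as produced by the DSL's + operator."""
    lhs: object
    rhs: object


@dataclass(frozen=True)
class Call:
    """Invocation of a user-defined function such as the field's add."""
    fn: str
    ret: str
    args: Tuple[object, ...]


def type_of(expr) -> str:
    if isinstance(expr, Var):
        return expr.tpe
    if isinstance(expr, Plus):
        # Simplified rule: arithmetic plus is only typed on Integer operands
        if type_of(expr.lhs) == "Integer" and type_of(expr.rhs) == "Integer":
            return "Integer"
        return "<untyped>"
    if isinstance(expr, Call):
        return expr.ret
    raise TypeError("unknown expression: %r" % (expr,))


y = Var("y", "element")
print(type_of(Plus(y, y)))                      # <untyped>
print(type_of(Call("add", "element", (y, y))))  # element
```

In this model the fix is whatever makes y+y build the Call node rather than the Plus node, which is what bringing the implicit class into scope is meant to achieve.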

Is there any means to serialize custom Transformer in Spark ML Pipeline

I use ML pipeline with various custom UDF-based transformers. What I'm looking for is a way to serialize/deserialize this pipeline.
I serialize the PipelineModel using
However whenever I try to deserialize the pipeline I'm having:
java.lang.ClassNotFoundException: org.sparkexample.DateTransformer
where DateTransformer is my custom transformer. Is there any method or interface I should implement for proper serialization?
I've found out there is
interface that might be implemented by my class (DateTransformer extends Transformer); however, I can't find a useful example of it.
If you are using Spark 2.x+, then extend your transformer with DefaultParamsWritable,
for example:
class ProbabilityMaxer extends Transformer with DefaultParamsWritable{
Then create a constructor with a string parameter
def this(_uid: String) {
Finally, for a successful read, add a companion object:
object ProbabilityMaxer extends DefaultParamsReadable[ProbabilityMaxer]
I have this working on my production server. I will add gitlab link to the project later when I upload it
The short answer is you can't, at least not easily.
The devs have gone out of their way to make adding a new transformer/estimator as difficult as possible. Basically all of the persistence utilities are private (except for MLWritable and MLReadable), so there is no way to use any of the utility methods/classes/objects. There is also (as I'm sure you've already discovered) absolutely no documentation on how this should be done, but hey, good code documents itself, right?!
Digging through the code, it seems that you need to override MLWritable.write and MLReadable.read. The DefaultParamsWriter and DefaultParamsReader, which seem to contain the actual save/load implementations, save and load a bunch of metadata:
- class
- timestamp
- sparkVersion
- uid
- paramMap
- (optionally, extra metadata)
so any implementation would at least need to cover those, and a transformer that doesn't need to learn any model would probably get away with just that. A model that needs to be fitted also needs to save that data in its implementation of save/write. For instance, this is what LocalLDAModel does, so the learned model is just saved as a Parquet file (it seems):
val data =
.select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
As a test, I copied everything that seems to be needed and tested the following transformer, which does not do anything useful.
WARNING: this is almost certainly the wrong thing to do and is very likely to break in the future. I sincerely hope I've misunderstood something and someone is going to correct me on how to actually create a transformer that can be serialised/deserialised.
This is for Spark 1.6.3 and may already be broken if you're using 2.x.
import org.apache.spark.sql.types._
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.util.{Identifiable, MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{SQLContext, DataFrame}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.mllib.linalg._
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
object CustomTransform extends DefaultParamsReadable[CustomTransform] {
  /* Companion object for deserialisation */
  override def load(path: String): CustomTransform = super.load(path)
}
class CustomTransform(override val uid: String)
  extends Transformer with DefaultParamsWritable {
  def this() = this(Identifiable.randomUID("customThing"))
  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)
  def getOutputCol(): String = getOrDefault(outputCol)
  val inputCol = new Param[String](this, "inputCol", "input column")
  val outputCol = new Param[String](this, "outputCol", "output column")
  override def transform(dataset: DataFrame): DataFrame = {
    val sqlContext = SQLContext.getOrCreate(SparkContext.getOrCreate())
    import sqlContext.implicits._
    val outCol = extractParamMap.getOrElse(outputCol, "output")
    val inCol = extractParamMap.getOrElse(inputCol, "input")
    // Multiply every stored value of the sparse vector by 10
    val transformUDF = udf({ vector: SparseVector => vector.values.map(_ * 10) })
    dataset.withColumn(outCol, transformUDF(col(inCol)))
  }
  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)
  override def transformSchema(schema: StructType): StructType = {
    val outputFields = schema.fields :+ StructField(extractParamMap.getOrElse(outputCol, "filtered"), new VectorUDT, nullable = false)
    StructType(outputFields)
  }
}
Then we need all the utilities from
trait DefaultParamsWritable extends MLWritable { self: Params =>
  override def write: MLWriter = new DefaultParamsWriter(this)
}
trait DefaultParamsReadable[T] extends MLReadable[T] {
  override def read: MLReader[T] = new DefaultParamsReader
}
class DefaultParamsWriter(instance: Params) extends MLWriter {
  override protected def saveImpl(path: String): Unit = {
    DefaultParamsWriter.saveMetadata(instance, path, sc)
  }
}
object DefaultParamsWriter {
  /**
   * Saves metadata + Params to: path + "/metadata"
   *  - class
   *  - timestamp
   *  - sparkVersion
   *  - uid
   *  - paramMap
   *  - (optionally, extra metadata)
   * @param extraMetadata  Extra metadata to be saved at same level as uid, paramMap, etc.
   * @param paramMap  If given, this is saved in the "paramMap" field.
   *                  Otherwise, all [[Param]]s are encoded using
   *                  [[Param.jsonEncode]].
   */
  def saveMetadata(
      instance: Params,
      path: String,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): Unit = {
    val uid = instance.uid
    val cls = instance.getClass.getName
    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
    val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
      p.name -> parse(p.jsonEncode(v))
    }.toList))
    val basicMetadata = ("class" -> cls) ~
      ("timestamp" -> System.currentTimeMillis()) ~
      ("sparkVersion" -> sc.version) ~
      ("uid" -> uid) ~
      ("paramMap" -> jsonParams)
    val metadata = extraMetadata match {
      case Some(jObject) =>
        basicMetadata ~ jObject
      case None =>
        basicMetadata
    }
    val metadataPath = new Path(path, "metadata").toString
    val metadataJson = compact(render(metadata))
    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
  }
}
class DefaultParamsReader[T] extends MLReader[T] {
  override def load(path: String): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Class.forName(metadata.className, true, Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader))
    // Instantiate through the single-String (uid) constructor
    val instance =
      cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
    DefaultParamsReader.getAndSetParams(instance, metadata)
    instance.asInstanceOf[T]
  }
}
object DefaultParamsReader {
  /**
   * All info from metadata file.
   * @param params  paramMap, as a [[JValue]]
   * @param metadata  All metadata, including the other fields
   * @param metadataJson  Full metadata file String (for debugging)
   */
  case class Metadata(
      className: String,
      uid: String,
      timestamp: Long,
      sparkVersion: String,
      params: JValue,
      metadata: JValue,
      metadataJson: String)
  /**
   * Load metadata from file.
   * @param expectedClassName  If non empty, this is checked against the loaded metadata.
   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
   */
  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataStr = sc.textFile(metadataPath, 1).first()
    val metadata = parse(metadataStr)
    implicit val format = DefaultFormats
    val className = (metadata \ "class").extract[String]
    val uid = (metadata \ "uid").extract[String]
    val timestamp = (metadata \ "timestamp").extract[Long]
    val sparkVersion = (metadata \ "sparkVersion").extract[String]
    val params = metadata \ "paramMap"
    if (expectedClassName.nonEmpty) {
      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
        s" $expectedClassName but found class name $className")
    }
    Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
  }
  /**
   * Extract Params from metadata, and set them in the instance.
   * This works if all Params implement [[Param.jsonDecode]].
   */
  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
    implicit val format = DefaultFormats
    metadata.params match {
      case JObject(pairs) =>
        pairs.foreach { case (paramName, jsonValue) =>
          val param = instance.getParam(paramName)
          val value = param.jsonDecode(compact(render(jsonValue)))
          instance.set(param, value)
        }
      case _ =>
        throw new IllegalArgumentException(
          s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
    }
  }
  /**
   * Load a [[Params]] instance from the given path, and return it.
   * This assumes the instance implements [[MLReadable]].
   */
  def loadParamsInstance[T](path: String, sc: SparkContext): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Class.forName(metadata.className, true, Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader))
    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
  }
}
With that in place you can use CustomTransform in a Pipeline and save/load the pipeline. I tested that fairly quickly in the Spark shell and it seems to work, but it certainly isn't pretty.
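Stripped of Spark, the save/load flow boils down to: write the class name, uid and params as JSON, then on load look the class up by name, construct it from the uid and set the params back. The sketch below is plain Python with only the standard library, so it is an illustration of the flow rather than anything Spark does verbatim. `KNOWN_CLASSES` is a made-up stand-in for Class.forName, `SPARK_VERSION` is a placeholder constant, and the toy `CustomTransform` only stores its params.

```python
import json
import time

SPARK_VERSION = "1.6.3"  # placeholder value, only recorded in the metadata


class CustomTransform:
    """Toy transformer: just a uid plus a dict of params."""

    def __init__(self, uid):
        self.uid = uid
        self.params = {}

    def set(self, name, value):
        self.params[name] = value
        return self


# Stand-in for Class.forName: classes that can be found at load time
KNOWN_CLASSES = {"CustomTransform": CustomTransform}


def save_metadata(instance) -> str:
    """Serialize the same fields the metadata file contains."""
    return json.dumps({
        "class": type(instance).__name__,
        "timestamp": int(time.time() * 1000),
        "sparkVersion": SPARK_VERSION,
        "uid": instance.uid,
        "paramMap": instance.params,
    })


def load_instance(metadata_json: str):
    metadata = json.loads(metadata_json)
    cls = KNOWN_CLASSES.get(metadata["class"])
    if cls is None:
        # Analogous to the ClassNotFoundException in the question
        raise LookupError("class not found: " + metadata["class"])
    # The class is rebuilt through its uid constructor, then params are set
    instance = cls(metadata["uid"])
    for name, value in metadata["paramMap"].items():
        instance.set(name, value)
    return instance


original = CustomTransform("customThing_1").set("inputCol", "in").set("outputCol", "out")
restored = load_instance(save_metadata(original))
print(restored.uid, restored.params)
```

This also shows why the loader needs both a class it can find by name and a constructor that takes the uid: without either one, nothing in the metadata is enough to rebuild the object.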