Defining infix operator in Welder - leon

I have the following simplified definition of an addition operation over a field:
import inox._
import inox.trees.{forall => _, _}
import inox.trees.dsl._
object Field {
  val element = FreshIdentifier("element")
  val zero = FreshIdentifier("zero")
  val one = FreshIdentifier("one")
  val elementADT = mkSort(element)()(Seq(zero, one))
  val zeroADT = mkConstructor(zero)()(Some(element)) {_ => Seq()}
  val oneADT = mkConstructor(one)()(Some(element)) {_ => Seq()}
  val addID = FreshIdentifier("add")
  val addFunction = mkFunDef(addID)("element") { case Seq(eT) =>
    val args: Seq[ValDef] = Seq("f1" :: eT, "f2" :: eT)
    val retType: Type = eT
    val body: Seq[Variable] => Expr = { case Seq(f1,f2) =>
      //do the addition for this field
      f1 //do something better...
    }
    (args, retType, body)
  }
  //-------Helper functions for arithmetic operations and zero element of field----------------
  implicit class ExprOperands(private val lhs: Expr) extends AnyVal{
    def +(rhs: Expr): Expr = E(addID)(T(element)())(lhs, rhs)
  }
}
I'd like to use this operation with infix notation, and the current solution that I found for doing so in Scala is given here. That is why I include the implicit class at the bottom.
Say I now want to use this definition of addition:
import inox._
import inox.trees.{forall => _, _}
import inox.trees.dsl._
import welder._
object Curve{
val affinePoint = FreshIdentifier("affinePoint")
val infinitePoint = FreshIdentifier("infinitePoint")
val finitePoint = FreshIdentifier("finitePoint")
val first = FreshIdentifier("first")
val second = FreshIdentifier("second")
val affinePointADT = mkSort(affinePoint)("F")(Seq(infinitePoint,finitePoint))
val infiniteADT = mkConstructor(infinitePoint)("F")(Some(affinePoint))(_ => Seq())
val finiteADT = mkConstructor(finitePoint)("F")(Some(affinePoint)){ case Seq(fT) =>
Seq(ValDef(first, fT), ValDef(second, fT))
}
val F = T(Field.element)()
val affine = T(affinePoint)(F)
val infinite = T(infinitePoint)(F)
val finite = T(finitePoint)(F)
val onCurveID = FreshIdentifier("onCurve")
val onCurveFunction = mkFunDef(onCurveID)() { case Seq() =>
val args: Seq[ValDef] = Seq("p" :: affine, "a" :: F, "b" :: F)
val retType: Type = BooleanType
val body: Seq[Variable] => Expr = { case Seq(p,a,b) =>
if_(p.isInstOf(finite)){
val x: Expr = p.asInstOf(finite).getField(first)
val y: Expr = p.asInstOf(finite).getField(second)
x === y+y
} else_ {
BooleanLiteral(true)
}
}
(args, retType, body)
}
//---------------------------Registering elements-----------------------------------
val symbols = NoSymbols
.withADTs(Seq(affinePointADT,
infiniteADT,
finiteADT,
Field.zeroADT,
Field.oneADT,
Field.elementADT))
.withFunctions(Seq(Field.addFunction,
onCurveFunction))
val program = InoxProgram(Context.empty, symbols)
val theory = theoryOf(program)
import theory._
val expr = (E(BigInt(1)) + E(BigInt(1))) === E(BigInt(2))
val theorem: Theorem = prove(expr)
}
This fails with the following error:
java.lang.ExceptionInInitializerError
at Main$.main(Main.scala:4)
at Main.main(Main.scala)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
Caused by: inox.ast.TypeOps$TypeErrorException: Type error: if (p.isInstanceOf[finitePoint[element]]) {
p.asInstanceOf[finitePoint[element]].first == p.asInstanceOf[finitePoint[element]].second + p.asInstanceOf[finitePoint[element]].second
} else {
true
}, expected Boolean, found <untyped>
at inox.ast.TypeOps$TypeErrorException$.apply(TypeOps.scala:24)
at inox.ast.TypeOps$class.typeCheck(TypeOps.scala:264)
at inox.ast.SimpleSymbols$SimpleSymbols.typeCheck(SimpleSymbols.scala:12)
at inox.ast.Definitions$AbstractSymbols$$anonfun$ensureWellFormed$2.apply(Definitions.scala:166)
at inox.ast.Definitions$AbstractSymbols$$anonfun$ensureWellFormed$2.apply(Definitions.scala:165)
at scala.collection.TraversableLike$WithFilter$$anonfun$foreach$1.apply(TraversableLike.scala:733)
at scala.collection.immutable.Map$Map2.foreach(Map.scala:137)
at scala.collection.TraversableLike$WithFilter.foreach(TraversableLike.scala:732)
at inox.ast.Definitions$AbstractSymbols$class.ensureWellFormed(Definitions.scala:165)
at inox.ast.SimpleSymbols$SimpleSymbols.ensureWellFormed$lzycompute(SimpleSymbols.scala:12)
at inox.ast.SimpleSymbols$SimpleSymbols.ensureWellFormed(SimpleSymbols.scala:12)
at inox.solvers.unrolling.AbstractUnrollingSolver$class.assertCnstr(UnrollingSolver.scala:129)
at inox.solvers.SolverFactory$$anonfun$getFromName$1$$anon$1.inox$tip$TipDebugger$$super$assertCnstr(SolverFactory.scala:115)
at inox.tip.TipDebugger$class.assertCnstr(TipDebugger.scala:52)
at inox.solvers.SolverFactory$$anonfun$getFromName$1$$anon$1.assertCnstr(SolverFactory.scala:115)
at inox.solvers.SolverFactory$$anonfun$getFromName$1$$anon$1.assertCnstr(SolverFactory.scala:115)
at welder.Solvers$class.prove(Solvers.scala:51)
at welder.package$$anon$1.prove(package.scala:10)
at welder.Solvers$class.prove(Solvers.scala:23)
at welder.package$$anon$1.prove(package.scala:10)
at Curve$.<init>(curve.scala:61)
at Curve$.<clinit>(curve.scala)
at Main$.main(Main.scala:4)
at Main.main(Main.scala)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
What is happening is that in the expression x === y+y the + is not applied correctly, so the expression is untyped. I recall that one cannot define nested objects or classes inside the objects of Welder proofs; I don't know whether that is related.
Anyway, do I have to give up on infix notation in my Welder code, or is there a solution?

The issue is that the implicit class you defined is not visible where you write y+y (you would need to import Field._ for it to be visible). The + being applied there is the one from the inox DSL, which gives you an arithmetic plus expression that expects integer arguments, hence the type error.
I don't remember exactly how implicit resolution works in Scala, so adding import Field._ inside the Curve object may override the + that comes from the inox DSL. Otherwise, you'll unfortunately have an ambiguity in the implicit resolution, and I'm not sure it's possible to use the infix + operator in that case without giving up the whole DSL.

Related

Writing an addition visitor function in Kotlin

I'm trying to write a visitor function in Kotlin that adds two integers together. I've been working from some sample code and I can't figure out what the .value and .visit members are. They don't seem to be declared in the sample code, so I'm unsure how to declare them in my code. Whenever I compile the code, I get an error saying that value is an unresolved reference.
Relevant Kotlin code:
package backend
import org.antlr.v4.runtime.*
import grammar.*
abstract class Data
class IntData(val value: Int): Data() {
    override fun toString(): String
        = "Int($value)"
}
class Context(): HashMap<String, Data>() {
    constructor(parent: Context): this() {
        this.putAll(parent)
    }
}
abstract class Expr {
    abstract fun eval(scope: Context): Data
    fun run(program: Expr) {
        try {
            val data = program.eval(Context())
            println("=> ${data}")
        } catch(e: Exception) {
            println("[err] ${e}")
        }
    }
}
class IntLiteral(val value: Int): Expr() {
    override fun eval(scope:Context): Data
        = IntData(value)
}
enum class Op {
    Add,
    Sub,
    Mul,
    Div
}
class Arithmetic(
    val op: Op,
    val left: Expr,
    val right: Expr): Expr() {
    override fun eval(scope: Context): Data {
        val x = (left.eval(scope) as IntData).value
        val y = (right.eval(scope) as IntData).value
        return IntData(
            when(op) {
                Op.Add -> x + y
                Op.Mul -> x * y
                Op.Sub -> x - y
                Op.Div -> x / y
            }
        )
    }
}
class Compiler: PLBaseVisitor<Expr>() {
    val scope = mutableMapOf<String, Expr>()
    override fun visitAddExpr(ctx: PLParser.AddExprContext): Expr {
        val xValue = this.visit(ctx.x)
        val yValue = this.visit(ctx.y)
        val result = xValue.value + yValue.value
        return IntLiteral(result)
    }
}
Relevant Antlr Grammar:
expr : x=expr '+' y=expr # addExpr
| x=expr '-' y=expr # subExpr
| x=expr '*' y=expr # mulExpr
| x=expr '/' y=expr # divExpr
;
Code I'm trying to execute:
val test = """
x=1+2
print(x)
"""
fun parse(source: String): PLParser.ProgramContext {
val input = CharStreams.fromString(source)
val lexer = PLLexer(input)
val tokens = CommonTokenStream(lexer)
val parser = PLParser(tokens)
}
val testTree = parse(source1)
fun execute(program: Expr?) {
if(program == null) {
println("Program is null.")
return
}
try {
val data = program.eval(Context())
println("> ${data}")
} catch(e: Exception) {
println("[err] ${e}")
}
}
execute(testProgram)
Code from sample:
data class NodeValue(val value: Int)
val visitor = object: CalcBaseVisitor<NodeValue>() {
override fun visitAddition(ctx: CalcParser.AdditionContext): NodeValue {
val xValue = this.visit(ctx.x)
val yValue = this.visit(ctx.y)
return NodeValue(xValue.value + yValue.value)
}
override fun visitValue(ctx: CalcParser.ValueContext): NodeValue {
val lexeme = ctx.Number().getText()
return NodeValue(lexeme.toInt())
}
}
You don’t show the code for your program.eval() method.
The eval function would need to create an instance of your Visitor. (You’ve done that and called it visitor).
You also have the root expr node in your program variable.
Now you would have your visitor “visit” that node and save the return value:
val nodeVal = visitor.visit(program)
At that point nodeVal.value will have the result of visiting that expression.
Note: since you’re doing the evaluation in your visitor, there’s not really any use for your Arithmetic class (unless you refactor your visitor to use it instead of just doing the math, but I don’t see much value in that, as the visitor is already pretty easy to read).
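The unresolved reference in the question arises because Compiler extends PLBaseVisitor<Expr>, so this.visit(ctx.x) returns an Expr, and Expr declares no value property (the sample's visitor returns NodeValue, which does). Below is a minimal sketch of the refactoring mentioned in the note above, in which the visitor builds an Arithmetic node and evaluation happens afterwards. The Node, NumberNode and AddNode classes and the free-standing visit function are hypothetical stand-ins for the ANTLR-generated contexts and visitor, and the scope parameter of eval is dropped for brevity:

```kotlin
// Trimmed copies of the question's classes (the Context parameter is omitted).
abstract class Data
class IntData(val value: Int) : Data() {
    override fun toString(): String = "Int($value)"
}

abstract class Expr {
    abstract fun eval(): Data
}

class IntLiteral(val value: Int) : Expr() {
    override fun eval(): Data = IntData(value)
}

enum class Op { Add, Sub, Mul, Div }

class Arithmetic(val op: Op, val left: Expr, val right: Expr) : Expr() {
    override fun eval(): Data {
        val x = (left.eval() as IntData).value
        val y = (right.eval() as IntData).value
        return IntData(
            when (op) {
                Op.Add -> x + y
                Op.Sub -> x - y
                Op.Mul -> x * y
                Op.Div -> x / y
            }
        )
    }
}

// Hypothetical parse-tree nodes standing in for the ANTLR-generated contexts.
sealed class Node
class NumberNode(val text: String) : Node()
class AddNode(val x: Node, val y: Node) : Node()

// The visitor returns Expr, so it builds a tree instead of reading `.value`.
fun visit(node: Node): Expr = when (node) {
    is NumberNode -> IntLiteral(node.text.toInt())
    is AddNode -> Arithmetic(Op.Add, visit(node.x), visit(node.y))
}

fun main() {
    val tree = AddNode(NumberNode("1"), NumberNode("2"))
    println(visit(tree).eval()) // prints Int(3)
}
```

With this shape the visitor only translates the parse tree, and the arithmetic stays in Arithmetic.eval; whether that split is worth it is the judgment call the note above describes.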

Kotlin DSL - union structure

I am designing a DSL and ran into a requirement where I have a variable which can be assigned in different ways. Greatly simplified, I would like to set the value property either by an integer or by an expression in a String. (The real need is even more complex.)
I would like to write in my DSL:
value = 42
or
value = "6*7"
Behind the scene, the value will be stored in a DynamicValue<Int> structure which contains either an integer or the expression.
class DynamicValue<T>(dv : T?, expr : String) {
val directValue : T? = dv
val script : String? = expr
...
}
I tried several ways (delegate, class, etc.), but none of them provided this syntax.
Is there a way to declare this union like structure?
What do you think about the following syntax:
value(42)
value("6*7")
//or
value+=42
value+="6*7"
You can do this with operator functions:
class DynamicValue<T>() {
    var dv: T? = null
    var expr: String? = null
    operator fun invoke(dv : T) {
        this.dv = dv
        this.expr = null
    }
    operator fun invoke(expr: String) {
        this.dv = null
        this.expr = expr
    }
    operator fun plusAssign(dv : T) {
        this.dv = dv
        this.expr = null
    }
    operator fun plusAssign(expr: String) {
        this.dv = null
        this.expr = expr
    }
}
You can't redefine the assignment operator in Kotlin, therefore the pure syntax value=42 is not possible.
But I wouldn't go with operator functions; it's too magical. I would do this:
val value = DynamicValue2<Int>()
value.simple=42
value.expr="6*7"
class DynamicValue2<T>() {
    private var _dv: T? = null
    private var _expr: String? = null
    var simple: T?
        get() = _dv
        set(value) {
            _dv = value
            _expr = null
        }
    var expr: String?
        get() = _expr
        set(value) {
            _expr = value
            _dv = null
        }
}
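For the "union like structure" the question asks about, a sealed class is another option. The following is only a sketch under assumptions: the names DynValue, Direct, Script, Holder and describe are made up for illustration and appear nowhere in the question. It does not give the value = 42 syntax, since the variant still has to be named at the assignment, but a single property then holds exactly one of the two forms:

```kotlin
// Hypothetical names: DynValue, Direct, Script, Holder and describe are illustrative only.
sealed class DynValue<out T> {
    data class Direct<T>(val value: T) : DynValue<T>()
    data class Script(val source: String) : DynValue<Nothing>()
}

class Holder {
    // One property; it always contains exactly one of the two variants.
    var value: DynValue<Int> = DynValue.Direct(0)
}

// `when` over a sealed class is exhaustive, so no else branch is needed.
fun describe(v: DynValue<Int>): String = when (v) {
    is DynValue.Direct -> "direct value ${v.value}"
    is DynValue.Script -> "script '${v.source}'"
}

fun main() {
    val h = Holder()
    h.value = DynValue.Direct(42)
    println(describe(h.value)) // prints: direct value 42
    h.value = DynValue.Script("6*7")
    println(describe(h.value)) // prints: script '6*7'
}
```

Compared with two nullable fields, there is no state in which both or neither of the values is set.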
Rene's answer gave me the lead, and I finally came up with this solution.
This solution covers all my requirements (including the ones I left out of my original question), so it became much more complicated than the original question would have required.
My whole requirement was to be able to add static values or scripts (snippets) running in a well-guarded context. These scripts would be stored and executed later. I wanted to enable the whole power of the IDE when writing the script, but also to guard my scripts from code injection and help the user to use only the context values the script requires.
The trick I used to achieve this is to allow the script to be written in Kotlin, but before I run the whole DSL script and create the business objects, I convert the script into a string. (This string will be executed later in a guarded, wrapped context by a JSR 223 engine.) This conversion forced me to tokenize the whole script before execution and search/replace some of the tokens. (The whole tokenizer and converter is rather long and boring, so I won't include it here.)
First approach
My goal was to be able to write any of these:
myobject {
    value = static { 42 } // A static solution
    value = static { 6 * 7 } // Even this is possible
    value = dynamic{ calc(x, y) } // A pure Kotlin solution with IDE support
    value = dynamic("""calc(x * x)""") // This is the form I convert the above script to
}
where calc, x and y are defined in the context class:
class SpecialScriptContext : ScriptContextBase() {
val hello = "Hello"
val x = 29
val y = 13
fun calc(x: Int, y: Int) = x + y
fun greet(name: String) = println("$hello $name!")
}
So let's see the solution! First I need a DynamicValue class to hold one of the values:
class DynamicValue<T, C : ScriptContextBase, D: ScriptContextDescriptor<C>>
private constructor(val directValue: T?, val script: String?) {
constructor(value: T?) : this(value, null)
constructor(script: String) : this(null, script)
}
This structure will ensure that exactly one of the options (static, script) will be set. (Don't bother with the C and D type parameters, they are for context-based script support.)
Then I made top level DSL functions to support syntax:
@PlsDsl
fun <T, C : ScriptContextBase, D : ScriptContextDescriptor<C>> static(block: () -> T): DynamicValue<T, C, D>
    = DynamicValue<T, C, D>(value = block.invoke())
@PlsDsl
fun <T, C : ScriptContextBase, D : ScriptContextDescriptor<C>> dynamic(s: String): DynamicValue<T, C, D>
    = DynamicValue<T, C, D>(script = s)
@PlsDsl
fun <T, C : ScriptContextBase, D : ScriptContextDescriptor<C>> dynamic(block: C.() -> T): DynamicValue<T, C, D> {
    throw IllegalStateException("Can't use this format")
}
An explanation of the third form: as I wrote before, I don't want to execute the block of the function. Before the script is executed, this form is converted to the string form, so normally this function never appears in the script when it runs. The exception is a sanity check, which should never be thrown.
Finally, I added the field to my business object builder:
@PlsDsl
class MyObjectBuilder {
    var value: DynamicValue<Int, SpecialScriptContext, SpecialScriptContextDescriptor>? = null
}
Second approach
The previous solution worked but had some flaws: the expression was associated neither with the variable it set nor with the entity the value was set in. With my second approach I solved this problem and removed the need for the equals sign and most of the unnecessary curly brackets.
What helped: extension functions, infix functions and sealed classes.
First, I split the two value types into separate classes with a common ancestor:
sealed class Value<T, C : ScriptContextBase> {
abstract val scriptExecutor: ScriptExecutor
abstract val descriptor: ScriptContextDescriptor<C>
abstract val code: String
abstract fun get(context: C): T?
}
class StaticValue<T, C : ScriptContextBase>(override val code: String,
override val scriptExecutor: ScriptExecutor,
override val descriptor: ScriptContextDescriptor<C>,
val value: T? = null
) : Value<T, C>() {
override fun get(context: C) = value
constructor(oldValue: Value<T, C>, value: T?) : this(oldValue.code, oldValue.scriptExecutor, oldValue.descriptor, value)
}
class DynamicValue<T, C : ScriptContextBase>(override val code: String,
script: String,
override val scriptExecutor: ScriptExecutor,
override val descriptor: ScriptContextDescriptor<C>)
: Value<T, C>() {
constructor(oldValue: Value<T, C>, script: String) : this(oldValue.code, script, oldValue.scriptExecutor, oldValue.descriptor)
private val scriptCache = scriptExecutor.register(descriptor)
val source = script?.replace("\\\"\\\"\\\"", "\"\"\"")
private val compiledScript = scriptCache.register(generateUniqueId(code), source)
override fun get(context: C): T? = compiledScript.execute<T?>(context)
}
Note that I made the primary constructor internal and created a kind of copy-and-alter constructor. Then I defined the new functions as extensions of the common ancestor and marked them infix:
infix fun <T, C : ScriptContextBase> Value<T, C>.static(value: T?): Value<T, C> = StaticValue(this, value)
infix fun <T, C : ScriptContextBase> Value<T, C>.expr(script: String): Value<T, C> = DynamicValue(this, script)
infix fun <T, C : ScriptContextBase> Value<T, C>.dynamic(block: C.() -> T): Value<T, C> {
throw IllegalStateException("Can't use this format")
}
Using the secondary copy-and-alter constructor allows the context-sensitive values to be inherited. Finally, I initialize the value inside the DSL builder:
@PlsDsl
class MyDslBuilder {
    var value: Value<Int, SpecialScriptContext> = StaticValue("pl.value", scriptExecutor, SpecialScriptContextDescriptor)
    var value2: Value<Int, SpecialScriptContext> = StaticValue("pl.value2", scriptExecutor, SpecialScriptContextDescriptor)
}
Everything is in place and now I can use it in my script:
myobject {
value static 42
value2 expr "6 * 7"
value2 dynamic { calc(x, y) }
}
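The infix mechanism of the second approach can be tried in isolation with a stripped-down sketch. It rests on several assumptions: the script executor, the context descriptor and the context type parameter are left out, and ScriptValue is a hypothetical stand-in for the answer's DynamicValue. Because the infix functions as declared above return a new Value rather than changing the receiver, the sketch assigns the result back explicitly:

```kotlin
// Simplified stand-ins: script executor, descriptor and context type are omitted.
sealed class Value<T> {
    abstract val code: String
}

class StaticValue<T>(override val code: String, val value: T? = null) : Value<T>()

// Hypothetical stand-in for the answer's DynamicValue; it only stores the script text.
class ScriptValue<T>(override val code: String, val script: String) : Value<T>()

// Each infix function keeps the identifying `code` and swaps the payload.
infix fun <T> Value<T>.static(value: T?): Value<T> = StaticValue(code, value)
infix fun <T> Value<T>.expr(script: String): Value<T> = ScriptValue(code, script)

class MyDslBuilder {
    var value: Value<Int> = StaticValue("pl.value")
    var value2: Value<Int> = StaticValue("pl.value2")
}

fun main() {
    val builder = MyDslBuilder()
    // The functions return a new Value, so this sketch stores the result.
    builder.value = builder.value static 42
    builder.value2 = builder.value2 expr "6 * 7"
    println((builder.value as StaticValue<Int>).value)   // prints 42
    println((builder.value2 as ScriptValue<Int>).script) // prints 6 * 7
}
```

How the answer's own DSL captures the result of value static 42 without an assignment is not shown in the post, so that part is left open here.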

"findMethods" doesn't return expected results

I'm trying to implement an analysis (extending DefaultOneStepAnalysis) that constructs a call graph using the CHA algorithm. My code has three parts:
1) the method "doAnalyze", which returns the "BasicReport"
2) the method "analyze", which finds the call edges for each method in the given project
3) the class "AnalysisContext", which stores the context and the methods used in the analysis.
In 3), I use the method "callBySignature" to find the cbsMethods of a method, in the same way as "CHACallGraphExtractor" does, but it doesn't return the expected result.
When I use OPAL's original way of getting cbsMethods in the Extractor, the result is a set of methods.
Could you please help me confirm where the problem is and how to solve it?
Thank you very much.
Regards,
Jiang
----Main Part of my code-------------------------------------------------
object CHACGAnalysis extends DefaultOneStepAnalysis {
... ...
override def doAnalyze(
project: Project[URL],
parameters: Seq[String] = List.empty,
isInterrupted: () ⇒ Boolean
): BasicReport = {
... ...
for {
classFile <- project.allProjectClassFiles
method <- classFile.methods
} {
analyze(project, methodToCellCompleter, classFile, method)
}
... ...
}
def analyze(
project: Project[URL],
methodToCellCompleter: Map[(String,Method), CellCompleter[K, Set[Method]]],
classFile: ClassFile,
method: Method
): Unit = {
… …
val context = new AnalysisContext(project, classFile, method)
method.body.get.foreach((pc, instruction) ⇒
instruction.opcode match {
... ...
case INVOKEINTERFACE.opcode ⇒
val INVOKEINTERFACE(declaringClass, name, descriptor) = instruction
context.addCallEdge_VirtualCall(pc, declaringClass, name, descriptor, true,cell1)
... ...
}
… …
}
protected[this] class AnalysisContext(
val project: SomeProject,
val classFile: ClassFile,
val method: Method
) {
val classHierarchy = project.classHierarchy
val cbsIndex = project.get(CallBySignatureResolutionKey)
val statistics = project.get(IntStatisticsKey)
val instantiableClasses = project.get(InstantiableClassesKey)
val cache = new CallGraphCache[MethodSignature, scala.collection.Set[Method]](project)
private[AnalysisContext] def callBySignature(
declaringClassType: ObjectType,
name: String,
descriptor: MethodDescriptor
): Set[Method] = {
val cbsMethods = cbsIndex.findMethods(
name,
descriptor,
declaringClassType
)
cbsMethods
}
def addCallEdge_VirtualCall(
pc: PC,
declaringClassType: ObjectType,
name: String,
descriptor: MethodDescriptor,
isInterfaceInvocation: Boolean = false,
cell1: CellCompleter[K, Set[Method]]
): Unit = {
val cbsCalls =
if (isInterfaceInvocation) {
callBySignature(declaringClassType, name, descriptor)
}
else
Set.empty[Method]
… …
}
… …
}
In the end, I found that the problem was caused by the "AnalysisMode".
After I reset the AnalysisMode to "CPA", the problem was solved.
I should always keep in mind which "AnalysisMode" to use before designing the algorithm.
Thank you for your concern.
Jiang

Null Pointer Exception working with Map (Kotlin)

I have the following class:
class SymbolTable(){
    var map = mutableMapOf<String, Entry>()
    var kindCounter = mutableMapOf<String, Int>()
    fun define(name:String, kind:String, type:String){
        if(kindCounter[kind]==0){
            kindCounter[kind]=0
        }
        var index = 1
        map[name]= Entry(type, kind, index)
        kindCounter[kind]=kindCounter[kind]!!.plus(1)
    }
}
class Entry looks like this:
class Entry(var type:String, var kind:String, var index:Int)
Main:
fun main(args:Array<String>){
var example = SymbolTable()
example.define("ex1", "ex1", "ex1")
example.define("ex2", "ex2", "ex2")
}
When I run the program and try using the "define" function I get the following error:
Exception in thread "main" kotlin.KotlinNullPointerException
at SymbolTable.define(SymbolTable.kt:21)
at SymbolTableKt.main(SymbolTable.kt:31)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at com.intellij.rt.execution.application.AppMain.main(AppMain.java:144)
I assume the problem has something to do with how I create a new symbolTable class, but seeing as Kotlin doesn't have "new" I don't know how to avoid the null pointer exception.
if(kindCounter[kind]==0){
kindCounter[kind]=0
}
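This doesn't make much sense: you test if the value is 0, and if it is, you set it to 0. So it's basically a no-op. A key that is not in the map yet yields null, not 0, so the counter is never initialized and the later !! throws.
What you want is to test if the value is null: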
if (kindCounter[kind] == null) {
kindCounter[kind] = 0
}
You could also avoid the dangerous !! operator by saving the value in a variable.
And you should really use val rather than var: none of your fields needs to be reassigned:
class SymbolTable() {
    val map = mutableMapOf<String, Entry>()
    val kindCounter = mutableMapOf<String, Int>()
    fun define(name: String, kind: String, type: String) {
        val count = kindCounter[kind] ?: 0
        map[name] = Entry(type, kind, 1)
        kindCounter[kind] = count + 1
    }
}
class Entry(val type: String, val kind: String, val index: Int)
fun main(args:Array<String>) {
    val example = SymbolTable()
    example.define("ex1", "ex1", "ex1")
    example.define("ex2", "ex2", "ex2")
}
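Both the question and the code above store a fixed index of 1 in every Entry. If the index is meant to be the running count for its kind (an assumption; the question does not say so), the same elvis-operator pattern covers that as well. A minimal sketch with made-up sample names:

```kotlin
class Entry(val type: String, val kind: String, val index: Int)

class SymbolTable {
    val map = mutableMapOf<String, Entry>()
    val kindCounter = mutableMapOf<String, Int>()

    fun define(name: String, kind: String, type: String) {
        // A missing key yields null, so fall back to 0 instead of using !!.
        val index = kindCounter[kind] ?: 0
        // Assumption: the entry's index is the zero-based count of its kind so far.
        map[name] = Entry(type, kind, index)
        kindCounter[kind] = index + 1
    }
}

fun main() {
    val table = SymbolTable()
    table.define("a", "field", "int")
    table.define("b", "field", "int")
    table.define("c", "static", "boolean")
    println(table.map.getValue("b").index) // prints 1
    println(table.kindCounter)             // prints {field=2, static=1}
}
```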

How to convert spark SchemaRDD into RDD of my case class?

In the Spark docs it's clear how to create Parquet files from an RDD of your own case classes (from the docs):
val people: RDD[Person] = ??? // An RDD of case class objects, from the previous example.
// The RDD is implicitly converted to a SchemaRDD by createSchemaRDD, allowing it to be stored using Parquet.
people.saveAsParquetFile("people.parquet")
But it is not clear how to convert back. What we really want is a method readParquetFile, so that we can write:
val people: RDD[Person] = sc.readParquetFile[Person](path)
where the fields defined in the case class are the ones read by the method.
An easy way is to provide your own converter (Row) => CaseClass. This is a bit more manual, but if you know what you are reading it should be quite straightforward.
Here is an example:
import org.apache.spark.sql.{Row, SchemaRDD}
case class User(data: String, name: String, id: Long)
// Convert a Row to a User; rows that do not match the expected shape yield None.
def sparkSqlToUser(r: Row): Option[User] = {
  r match {
    case Row(time: String, name: String, id: Long) => Some(User(time, name, id))
    case _ => None
  }
}
val parquetData: SchemaRDD = sqlContext.parquetFile("hdfs://localhost/user/data.parquet")
val caseClassRdd: org.apache.spark.rdd.RDD[User] = parquetData.flatMap(sparkSqlToUser)
The best solution I've come up with, requiring the least copying and pasting for new classes, is as follows (I'd still like to see another solution, though).
First you have to define your case class and a (partially) reusable factory method:
import org.apache.spark.sql.catalyst.expressions
case class MyClass(fooBar: Long, fred: Long)
// Here you want to auto gen these functions using macros or something
object Factories extends java.io.Serializable {
def longLong[T](fac: (Long, Long) => T)(row: expressions.Row): T =
fac(row(0).asInstanceOf[Long], row(1).asInstanceOf[Long])
}
Some boilerplate, which will already be available:
import scala.reflect.runtime.universe._
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext.createSchemaRDD
The magic
import scala.reflect.ClassTag
import org.apache.spark.sql.SchemaRDD
def camelToUnderscores(name: String) =
"[A-Z]".r.replaceAllIn(name, "_" + _.group(0).toLowerCase())
def getCaseMethods[T: TypeTag]: List[String] = typeOf[T].members.sorted.collect {
case m: MethodSymbol if m.isCaseAccessor => m
}.toList.map(_.toString)
def caseClassToSQLCols[T: TypeTag]: List[String] =
getCaseMethods[T].map(_.split(" ")(1)).map(camelToUnderscores)
def schemaRDDToRDD[T: TypeTag: ClassTag](schemaRDD: SchemaRDD, fac: expressions.Row => T) = {
val tmpName = "tmpTableName" // Maybe should use a random string
schemaRDD.registerAsTable(tmpName)
sqlContext.sql("SELECT " + caseClassToSQLCols[T].mkString(", ") + " FROM " + tmpName)
.map(fac)
}
Example use
val parquetFile = sqlContext.parquetFile(path)
val normalRDD: RDD[MyClass] =
schemaRDDToRDD[MyClass](parquetFile, Factories.longLong[MyClass](MyClass.apply))
See also:
http://apache-spark-user-list.1001560.n3.nabble.com/Spark-SQL-Convert-SchemaRDD-back-to-RDD-td9071.html
Though I failed to find any example or documentation by following the JIRA link.
There is a simple way to convert a SchemaRDD to an RDD using PySpark in Spark 1.2.1:
sc = SparkContext() ## create SparkContext
srdd = sqlContext.sql(sql)
c = srdd.collect() ## convert rdd to list
rdd = sc.parallelize(c)
There must be a similar approach in Scala.
Very crufty attempt. Very unconvinced this will have decent performance. Surely there must be a macro-based alternative...
import scala.reflect.runtime.universe.typeOf
import scala.reflect.runtime.universe.MethodSymbol
import scala.reflect.runtime.universe.NullaryMethodType
import scala.reflect.runtime.universe.TypeRef
import scala.reflect.runtime.universe.Type
import scala.reflect.runtime.universe.NoType
import scala.reflect.runtime.universe.termNames
import scala.reflect.runtime.universe.runtimeMirror
schemaRdd.map(row => RowToCaseClass.rowToCaseClass(row.toSeq, typeOf[X], 0))
object RowToCaseClass {
  // http://dcsobral.blogspot.com/2012/08/json-serialization-with-reflection-in.html
  def rowToCaseClass(record: Seq[_], t: Type, depth: Int): Any = {
    val fields = t.decls.sorted.collect {
      case m: MethodSymbol if m.isCaseAccessor => m
    }
    val values = fields.zipWithIndex.map {
      case (field, i) =>
        field.typeSignature match {
          case NullaryMethodType(sig) if sig =:= typeOf[String] => record(i).asInstanceOf[String]
          case NullaryMethodType(sig) if sig =:= typeOf[Int] => record(i).asInstanceOf[Int]
          case NullaryMethodType(sig) =>
            if (sig.baseType(typeOf[Seq[_]].typeSymbol) != NoType) {
              sig match {
                case TypeRef(_, _, args) =>
                  record(i).asInstanceOf[Seq[Seq[_]]].map {
                    r => rowToCaseClass(r, args(0), depth + 1)
                  }.toSeq
              }
            } else {
              sig match {
                case TypeRef(_, u, _) =>
                  rowToCaseClass(record(i).asInstanceOf[Seq[_]], sig, depth + 1)
              }
            }
        }
    }.asInstanceOf[Seq[Object]]
    val mirror = runtimeMirror(t.getClass.getClassLoader)
    val ctor = t.member(termNames.CONSTRUCTOR).asMethod
    val klass = t.typeSymbol.asClass
    val method = mirror.reflectClass(klass).reflectConstructor(ctor)
    method.apply(values: _*)
  }
}