How to mimic `tableswitch` using `MethodHandle`? - jvm

Context: I've been benchmarking the difference between using invokedynamic and manually generating bytecode (this is in the context of deciding whether a compiler targeting the JVM should emit more verbose "traditional" bytecode or just an invokedynamic call with a clever bootstrap method). In doing this, it has been pretty straightforward to map bytecode into MethodHandle combinators that are at least as fast, with the exception of tableswitch.
Question: Is there a trick to mimic tableswitch using MethodHandle? I tried mimicking it with a jump table: using a constant MethodHandle[], indexing into that with arrayElementGetter, then calling the found handle with MethodHandles.invoker. However, that ended up being around 50% slower than the original bytecode when I ran it through JMH.
Here's the code for producing the method handle:
private static MethodHandle makeProductElement(Class<?> receiverClass, List<MethodHandle> getters) {
    MethodHandle[] boxedGetters = getters
        .stream()
        .map(getter -> getter.asType(getter.type().changeReturnType(java.lang.Object.class)))
        .toArray(MethodHandle[]::new);

    MethodHandle getGetter = MethodHandles                        // (I)H
        .arrayElementGetter(MethodHandle[].class)
        .bindTo(boxedGetters);

    MethodHandle invokeGetter = MethodHandles.permuteArguments(   // (RH)O
        MethodHandles.invoker(MethodType.methodType(java.lang.Object.class, receiverClass)),
        MethodType.methodType(java.lang.Object.class, receiverClass, MethodHandle.class),
        1,
        0
    );

    return MethodHandles.filterArguments(invokeGetter, 1, getGetter);
}
Here's the initial bytecode (which I'm trying to replace with one invokedynamic call):
public java.lang.Object productElement(int);
  descriptor: (I)Ljava/lang/Object;
  flags: (0x0001) ACC_PUBLIC
  Code:
    stack=3, locals=3, args_size=2
       0: iload_1
       1: istore_2
       2: iload_2
       3: tableswitch   { // 0 to 2
                     0: 28
                     1: 38
                     2: 45
               default: 55
          }
      28: aload_0
      29: invokevirtual #62   // Method i:()I
      32: invokestatic  #81   // Method java/lang/Integer.valueOf:(I)Ljava/lang/Integer;
      35: goto          67
      38: aload_0
      39: invokevirtual #65   // Method s:()Ljava/lang/String;
      42: goto          67
      45: aload_0
      46: invokevirtual #68   // Method l:()J
      49: invokestatic  #85   // Method java/lang/Long.valueOf:(J)Ljava/lang/Long;
      52: goto          67
      55: new           #87   // class java/lang/IndexOutOfBoundsException
      58: dup
      59: iload_1
      60: invokestatic  #93   // Method java/lang/Integer.toString:(I)Ljava/lang/String;
      63: invokespecial #96   // Method java/lang/IndexOutOfBoundsException."<init>":(Ljava/lang/String;)V
      66: athrow
      67: areturn

The good thing about invokedynamic is that it allows postponing the decision of how to implement the operation until actual runtime. This is the trick behind LambdaMetafactory or StringConcatFactory, which may return composed method handles, like in your example code, or dynamically generated code, at the particular implementation’s discretion.
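For illustration, a bootstrap method wired to such a composed handle can be as small as the following sketch (the name bootstrapProductElement and the findGetters helper are hypothetical; makeProductElement is the method from your question):

import java.lang.invoke.*;
import java.util.List;

public static CallSite bootstrapProductElement(
        MethodHandles.Lookup lookup, String name, MethodType type,
        Object... args) throws Throwable {
    // The JVM calls this once, when the invokedynamic instruction is first
    // executed, and permanently links the call site to whatever we return.
    Class<?> receiverClass = type.parameterType(0);
    List<MethodHandle> getters = findGetters(lookup, receiverClass); // hypothetical helper
    MethodHandle target = makeProductElement(receiverClass, getters);
    return new ConstantCallSite(target.asType(type));
}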
There’s even a combined approach possible: generate classes which you then compose into an operation, e.g. building on the already existing LambdaMetafactory:
private static MethodHandle makeProductElement(
        MethodHandles.Lookup lookup, Class<?> receiverClass, List<MethodHandle> getters)
        throws Throwable {
    Function[] boxedGetters = new Function[getters.size()];
    MethodType factory = MethodType.methodType(Function.class);
    for(int ix = 0; ix < boxedGetters.length; ix++) {
        MethodHandle mh = getters.get(ix);
        MethodType actual = mh.type().wrap(), generic = actual.erase();
        boxedGetters[ix] = (Function)LambdaMetafactory.metafactory(lookup,
            "apply", factory, generic, mh, actual).getTarget().invokeExact();
    }
    Object switcher = new Object() {
        final Object get(Object receiver, int index) {
            return boxedGetters[index].apply(receiver);
        }
    };
    return lookup.bind(switcher, "get",
            MethodType.methodType(Object.class, Object.class, int.class))
        .asType(MethodType.methodType(Object.class, receiverClass, int.class));
}
This uses the LambdaMetafactory to generate a Function instance for each getter, similar to equivalent method references. Then, an actual class calling the right Function’s apply method is instantiated and a method handle to its get method returned.
This is a similar composition to your method handles, but with the reference implementation, fully materialized classes are used instead of handles. I’d expect the composed handles and this approach to converge to the same performance for a very large number of invocations, but the materialized classes have a head start for a medium number of invocations.
I added a first parameter MethodHandles.Lookup lookup which should be the lookup object received by the bootstrap method for the invokedynamic instruction. If used that way, the generated functions can access all methods the same way as the code containing the invokedynamic instruction, including private methods of that class.
Alternatively, you can generate a class containing a real switch instruction yourself. Using the ASM library, it may look like:
private static MethodHandle makeProductElement(
        MethodHandles.Lookup lookup, Class<?> receiverClass, List<MethodHandle> getters)
        throws ReflectiveOperationException {
    ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
    cw.visit(V1_8, ACC_INTERFACE|ACC_ABSTRACT,
        lookup.lookupClass().getName().replace('.', '/')+"$Switch", null,
        "java/lang/Object", null);
    MethodType type = MethodType.methodType(Object.class, receiverClass, int.class);
    MethodVisitor mv = cw.visitMethod(ACC_STATIC|ACC_PUBLIC, "get",
        type.toMethodDescriptorString(), null, null);
    mv.visitCode();
    Label defaultCase = new Label();
    Label[] cases = new Label[getters.size()];
    for(int ix = 0; ix < cases.length; ix++) cases[ix] = new Label();

    mv.visitVarInsn(ALOAD, 0);
    mv.visitVarInsn(ILOAD, 1);
    mv.visitTableSwitchInsn(0, cases.length - 1, defaultCase, cases);
    String owner = receiverClass.getName().replace('.', '/');
    for(int ix = 0; ix < cases.length; ix++) {
        mv.visitLabel(cases[ix]);
        MethodHandle mh = getters.get(ix);
        mv.visitMethodInsn(INVOKEVIRTUAL, owner, lookup.revealDirect(mh).getName(),
            mh.type().dropParameterTypes(0, 1).toMethodDescriptorString(), false);
        if(mh.type().returnType().isPrimitive()) {
            Class<?> boxed = mh.type().wrap().returnType();
            MethodType box = MethodType.methodType(boxed, mh.type().returnType());
            mv.visitMethodInsn(INVOKESTATIC, boxed.getName().replace('.', '/'),
                "valueOf", box.toMethodDescriptorString(), false);
        }
        mv.visitInsn(ARETURN);
    }
    mv.visitLabel(defaultCase);
    mv.visitTypeInsn(NEW, "java/lang/IndexOutOfBoundsException");
    mv.visitInsn(DUP);
    mv.visitVarInsn(ILOAD, 1);
    mv.visitMethodInsn(INVOKESTATIC, "java/lang/String",
        "valueOf", "(I)Ljava/lang/String;", false);
    mv.visitMethodInsn(INVOKESPECIAL, "java/lang/IndexOutOfBoundsException",
        "<init>", "(Ljava/lang/String;)V", false);
    mv.visitInsn(ATHROW);
    mv.visitMaxs(-1, -1);
    mv.visitEnd();
    cw.visitEnd();
    lookup = lookup.defineHiddenClass(
        cw.toByteArray(), true, MethodHandles.Lookup.ClassOption.NESTMATE);
    return lookup.findStatic(lookup.lookupClass(), "get", type);
}
This generates a new class with a static method containing the tableswitch instruction and the invocations (as well as the boxing conversions we now have to do ourselves). Also, it has the necessary code to create and throw an exception for out-of-bounds values. After generating the class, it returns a handle to that static method.

I don't know your timeline, but it is likely there will be a MethodHandles.tableSwitch operation in Java 17. It is currently being integrated via https://github.com/openjdk/jdk/pull/3401/
Some more discussion about it here:
https://mail.openjdk.java.net/pipermail/core-libs-dev/2021-April/076105.html
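If that timeline works for you, a sketch of the question's makeProductElement on top of the new combinator could look like this (assuming Java 17+; tableSwitch expects the int selector as the leading argument, so the receiver and index are swapped at the end):

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.List;

private static MethodHandle makeProductElement(Class<?> receiverClass, List<MethodHandle> getters)
        throws ReflectiveOperationException {
    // All case handles must share one type; the leading int selector is ignored by the getters.
    MethodHandle[] cases = getters.stream()
        .map(g -> g.asType(MethodType.methodType(Object.class, receiverClass)))
        .map(g -> MethodHandles.dropArguments(g, 0, int.class))
        .toArray(MethodHandle[]::new);

    // Default case: an (int, Receiver)Object handle that throws IndexOutOfBoundsException(int).
    MethodHandle newException = MethodHandles.lookup().findConstructor(
        IndexOutOfBoundsException.class, MethodType.methodType(void.class, int.class));
    MethodHandle fallback = MethodHandles.dropArguments(
        MethodHandles.filterArguments(
            MethodHandles.throwException(Object.class, IndexOutOfBoundsException.class),
            0, newException),
        1, receiverClass);

    MethodHandle switcher = MethodHandles.tableSwitch(fallback, cases);
    // switcher is (int, Receiver)Object; expose it as (Receiver, int)Object.
    return MethodHandles.permuteArguments(switcher,
        MethodType.methodType(Object.class, receiverClass, int.class), 1, 0);
}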

The thing is, tableswitch isn't always compiled to a jump table. For a small number of labels, like in your example, it's likely to act as a binary search. Thus a tree of regular "if-then" MethodHandles will be the closest equivalent.
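A minimal sketch of building such a tree with guardWithTest, assuming each case handle has already been adapted to a common (Receiver, int)Object type as in the question (the int being the otherwise ignored selector) and fallback handles out-of-range indices:

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

final class SwitchTree {
    static boolean intEq(int bound, int x) { return x == bound; }
    static boolean intLe(int bound, int x) { return x <= bound; }

    // Binary decision tree over the index range [lo, hi].
    static MethodHandle build(MethodHandle[] cases, int lo, int hi,
                              MethodHandle fallback, Class<?> receiverClass)
            throws ReflectiveOperationException {
        MethodHandles.Lookup lookup = MethodHandles.lookup();
        MethodType testType = MethodType.methodType(boolean.class, int.class, int.class);
        if (lo == hi) { // leaf: exact match, else fall through to the default
            MethodHandle eq = MethodHandles.dropArguments(
                lookup.findStatic(SwitchTree.class, "intEq", testType).bindTo(lo),
                0, receiverClass); // (Receiver, int)boolean
            return MethodHandles.guardWithTest(eq, cases[lo], fallback);
        }
        int mid = (lo + hi) >>> 1;
        MethodHandle le = MethodHandles.dropArguments(
            lookup.findStatic(SwitchTree.class, "intLe", testType).bindTo(mid),
            0, receiverClass);
        return MethodHandles.guardWithTest(le,
            build(cases, lo, mid, fallback, receiverClass),
            build(cases, mid + 1, hi, fallback, receiverClass));
    }
}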

Related

Generate a tree of structs with testing/quick, respecting invariants

I have a tree of structs which I'd like to test using testing/quick, while constraining it to my invariants.
This example code works:
var rnd = rand.New(rand.NewSource(time.Now().UnixNano()))

type X struct {
    HasChildren bool
    Children    []*X
}

func TestSomething(t *testing.T) {
    x, _ := quick.Value(reflect.TypeOf(X{}), rnd)
    _ = x
    // test some stuff here
}
But we hold HasChildren = true whenever len(Children) > 0 as an invariant, so it'd be better to ensure that whatever quick.Value() generates respects that (rather than finding "bugs" that don't actually exist).
I figured I could define a Generate function which uses quick.Value() to populate all the variable members:
func (X) Generate(rand *rand.Rand, size int) reflect.Value {
    x := X{}
    throwaway, _ := quick.Value(reflect.TypeOf([]*X{}), rand)
    x.Children = throwaway.Interface().([]*X)
    if len(x.Children) > 0 {
        x.HasChildren = true
    } else {
        x.HasChildren = false
    }
    return reflect.ValueOf(x)
}
But this is panicking:
panic: value method main.X.Generate called using nil *X pointer [recovered]
And when I change Children from []*X to []X, it dies with a stack overflow.
The documentation is very thin on examples, and I'm finding almost nothing in web searches either.
How can this be done?
Looking at the testing/quick source code, it seems that you can't create recursive custom generators and at the same time reuse the quick library facilities to generate the slice part of the struct, because the size parameter, which is designed to limit the number of recursive calls, cannot be passed back into quick.Value(...):
https://golang.org/src/testing/quick/quick.go (see around line 50)
In your case this leads to an infinite tree that quickly "explodes" with 1..50 children at each level (that's the reason for the stack overflow).
If the function quick.sizedValue() had been public we could have used it to accomplish your task, but unfortunately this is not the case.
BTW since HasChildren is an invariant, can't you simply make it a struct method?
type X struct {
    Children []*X
}

func (me *X) HasChildren() bool {
    return len(me.Children) > 0
}

func main() {
    // .... generate X ....
    if x.HasChildren() {
        // .....
    }
}

Array/List iteration without extra object allocations

I'm working on a game written in Kotlin and was looking into improving GC churn. One of the major sources of churn is for-loops called in the main game/rendering loops that result in the allocation of iterators.
Turning to the documentation, I found this paragraph:
A for loop over an array is compiled to an index-based loop that does not create an iterator object.
If you want to iterate through an array or a list with an index, you can do it this way:
for (i in array.indices)
print(array[i])
Note that this “iteration through a range” is compiled down to optimal implementation with no extra objects created.
https://kotlinlang.org/docs/reference/control-flow.html#for-loops
Is this really true? To verify, I took this simple Kotlin program and inspected the generated bytecode:
fun main(args: Array<String>) {
    val arr = arrayOf(1, 2, 3)
    for (i in arr.indices) {
        println(arr[i])
    }
}
According to the quote above, this should not result in any objects allocated, but get compiled down to a good old pre-Java-5 style for-loop. However, what I got was this:
41: aload_1
42: checkcast #23 // class "[Ljava/lang/Object;"
45: invokestatic #31 // Method kotlin/collections/ArraysKt.getIndices:([Ljava/lang/Object;)Lkotlin/ranges/IntRange;
48: dup
49: invokevirtual #37 // Method kotlin/ranges/IntRange.getFirst:()I
52: istore_2
53: invokevirtual #40 // Method kotlin/ranges/IntRange.getLast:()I
56: istore_3
57: iload_2
58: iload_3
59: if_icmpgt 93
This looks to me as if a method called getIndices is called that allocates a temporary IntRange object to back up bounds checking in this loop. How is this an "optimal implementation" with "no extra objects created", or am I missing something?
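For comparison, the "optimal implementation" the documentation promises would be the plain index-based loop; in Java terms (my reading of the quote, not decompiled output):

// The allocation-free shape the documentation describes.
for (int i = 0; i < arr.length; i++) {
    System.out.println(arr[i]);
}

No IntRange and no Iterator, just an int counter.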
UPDATE:
So, after toying around a bit more and looking at the answers, the following appears to be true for Kotlin 1.0.2:
Arrays:
for (i in array.indices): range allocation
for (i in 0..array.size - 1): no allocation
for (el in array): no allocation
array.forEach: no allocation
Collections:
for (i in coll.indices): range allocation
for (i in 0..coll.size - 1): no allocation
for (el in coll): iterator allocation
coll.forEach: iterator allocation
To iterate an array without allocating extra objects, you can use one of the following ways.
for-loop
for (e in arr) {
    println(e)
}
forEach extension
arr.forEach {
    println(it)
}
forEachIndexed extension, if you need to know the index of each element
arr.forEachIndexed { index, e ->
    println("$e at $index")
}
As far as I know the only allocation-less way to define a for loop is
for (i in 0..count - 1)
All other forms lead to either a Range allocation or an Iterator allocation. Unfortunately, you cannot even define an effective reverse for loop.
Here is an example of preparing a list and iterating over it with index and value.
val list = arrayListOf("1", "11", "111")
for ((index, value) in list.withIndex()) {
    println("$index: $value")
}
Output:
0: 1
1: 11
2: 111
Also, the following code works similarly:
val simplearray = arrayOf(1, 2, 3, 4, 5)
for ((index, value) in simplearray.withIndex()) {
    println("$index: $value")
}

Singleton implementation in Perl 6

I looked at the following code on Rosetta Code (http://rosettacode.org/wiki/Singleton#Perl_6), which implements a singleton in Perl 6:
class Singleton {
    has Int $.x is rw;
    # We create a lexical variable in the class block that holds our single instance.
    my Singleton $instance = Singleton.bless; # You can add initialization arguments here.
    method new {!!!} # Singleton.new dies.
    method instance { $instance; }
}
my $a = Singleton.bless(x => 1);
my $b = Singleton.bless(x => 2);
say $a.x;
say $b.x;
# result
# 1
# 2
But it seems that with this implementation I can create two instances of the same class using bless (see the example above).
Is there a way to restrict the implementation so that only one instance of the class can exist?
Perl prides itself on providing many ways to do things, leaving you to pick the one that suits your tastes and the application at hand. I say that to highlight that this is one simple but solid, hopefully self-explanatory way; I'm not putting it forward as "the best" because that depends on your circumstances.
#!/usr/bin/env perl6
class Scoreboard {
    has Str $.home-team ;
    has Str $.away-team ;
    has Int $.home-score = 0 ;
    has Int $.away-score = 0 ;

    my Scoreboard $instance;

    method new(*%named) {
        return $instance //= self.bless(|%named) ;
    }
    multi method goal($team where * eq $!home-team, Int :$points = 6) {
        $!home-score += $points
    }
    multi method goal($team where * eq $!away-team, Int :$points = 6) {
        $!away-score += $points
    }
    method Str {
        "At this vital stage of the game " ~
        do given $!home-score <=> $!away-score {
            when More {
                "$!home-team are leading $!away-team, $!home-score points to $!away-score"
            }
            when Less {
                "$!home-team are behind $!away-team, $!home-score points to $!away-score"
            }
            default {
                "the scores are level! $!home-score apiece!"
            }
        }
    }
}
my $home-team = "The Rabid Rabbits";
my $away-team = "The Turquoise Turtles"; # Go them turtles!

my $scoreboard = Scoreboard.new( :$home-team, :$away-team );
$scoreboard.goal($home-team, :4points) ;
say "$scoreboard";
$scoreboard.goal($away-team) ;
say "$scoreboard";

my $evil_second_scoreboard = Scoreboard.new;
$evil_second_scoreboard.goal($home-team, :2points) ;
say "$evil_second_scoreboard";
This produces:
At this vital stage of the game The Rabid Rabbits are leading The Turquoise Turtles, 4 points to 0
At this vital stage of the game The Rabid Rabbits are behind The Turquoise Turtles, 4 points to 6
At this vital stage of the game the scores are level! 6 apiece!
This overrides the default new (normally supplied by class Mu) and keeps a reference to ourself (i.e. this object) in private class data. For private class data, we use a lexically scoped scalar declared with my. The //= is defined-or assignment: the right-hand side is evaluated and assigned only when the left-hand side is undefined. So, on the first run, we call bless, which allocates the object and initializes the attributes, and then assign it to $instance. In subsequent calls to new, $instance is defined and is immediately returned.
If you want to prevent someone calling bless directly, you can add:
method bless(|) {
    nextsame unless $instance;
    fail "bless called on singleton Scoreboard"
}
which will ensure that only the first call will work.

How do I write a generic memoize function?

I'm writing a function to find triangle numbers and the natural way to write it is recursively:
function triangle (x)
    if x == 0 then return 0 end
    return x + triangle(x-1)
end
But attempting to calculate the first 100,000 triangle numbers fails with a stack overflow after a while. This is an ideal function to memoize, but I want a solution that will memoize any function I pass to it.
Mathematica has a particularly slick way to do memoization, relying on the fact that hashes and function calls use the same syntax:
triangle[0] = 0;
triangle[x_] := triangle[x] = x + triangle[x-1]
That's it. It works because the rules for pattern-matching function calls are such that it always uses a more specific definition before a more general definition.
Of course, as has been pointed out, this example has a closed-form solution: triangle[x_] := x*(x+1)/2. Fibonacci numbers are the classic example of how adding memoization gives a drastic speedup:
fib[0] = 1;
fib[1] = 1;
fib[n_] := fib[n] = fib[n-1] + fib[n-2]
Although that too has a closed-form equivalent, albeit messier: http://mathworld.wolfram.com/FibonacciNumber.html
I disagree with the person who suggested this was inappropriate for memoization because you could "just use a loop". The point of memoization is that any repeat function calls are O(1) time. That's a lot better than O(n). In fact, you could even concoct a scenario where the memoized implementation has better performance than the closed-form implementation!
You're also asking the wrong question for your original problem ;)
This is a better way for that case:
triangle(n) = n * (n + 1) / 2
Furthermore, supposing the formula didn't have such a neat solution, memoisation would still be a poor approach here. You'd be better off just writing a simple loop in this case. See this answer for a fuller discussion.
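Since the question asks for a generic memoize function and this thread spans several languages anyway, here is a minimal single-argument sketch in Java. Unlike the C# version further down, it does not tie the recursive knot, so it only caches top-level calls:

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

class Memo {
    // Repeat calls with the same argument become a map lookup instead of re-running f.
    static <A, R> Function<A, R> memoize(Function<A, R> f) {
        Map<A, R> cache = new ConcurrentHashMap<>();
        return a -> cache.computeIfAbsent(a, f);
    }
}

// e.g.: Function<Integer, Integer> slowSquare = Memo.memoize(x -> x * x);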
I bet something like this should work with variable argument lists in Lua:
local function varg_tostring(...)
    local s = select(1, ...)
    for n = 2, select('#', ...) do
        s = s..","..select(n,...)
    end
    return s
end

local function memoize(f)
    local cache = {}
    return function (...)
        local al = varg_tostring(...)
        if cache[al] then
            return cache[al]
        else
            local y = f(...)
            cache[al] = y
            return y
        end
    end
end
You could probably also do something clever with metatables and __tostring so that the argument list could just be converted with a tostring(). Oh, the possibilities.
In C# 3.0, for recursive functions, you can do something like this:
public static class Helpers
{
    public static Func<A, R> Memoize<A, R>(this Func<A, Func<A,R>, R> f)
    {
        var map = new Dictionary<A, R>();
        Func<A, R> self = null;
        self = (a) =>
        {
            R value;
            if (map.TryGetValue(a, out value))
                return value;
            value = f(a, self);
            map.Add(a, value);
            return value;
        };
        return self;
    }
}
Then you can create a memoized Fibonacci function like this:
var memoized_fib = Helpers.Memoize<int, int>((n,fib) => n > 1 ? fib(n - 1) + fib(n - 2) : n);
Console.WriteLine(memoized_fib(40));
In Scala (untested):
def memoize[A, B](f: (A)=>B) = {
    var cache = Map[A, B]()

    { x: A =>
        if (cache contains x) cache(x) else {
            val back = f(x)
            cache += (x -> back)
            back
        }
    }
}
Note that this only works for functions of arity 1, but with currying you could make it work. The more subtle problem is that memoize(f) != memoize(f) for any function f. One very sneaky way to fix this would be something like the following:
val correctMem = memoize(memoize _)
I don't think that this will compile, but it does illustrate the idea.
Update: Commenters have pointed out that memoization is a good way to optimize recursion. Admittedly, I hadn't considered this before, since I generally work in a language (C#) where generalized memoization isn't so trivial to build. Take the post below with that grain of salt in mind.
I think Luke likely has the most appropriate solution to this problem, but memoization is not generally the solution to any issue of stack overflow.
Stack overflow usually is caused by recursion going deeper than the platform can handle. Languages sometimes support "tail recursion", which re-uses the context of the current call, rather than creating a new context for the recursive call. But a lot of mainstream languages/platforms don't support this. C# has no inherent support for tail-recursion, for example. The 64-bit version of the .NET JITter can apply it as an optimization at the IL level, which is all but useless if you need to support 32-bit platforms.
If your language doesn't support tail recursion, your best option for avoiding stack overflows is either to convert to an explicit loop (much less elegant, but sometimes necessary), or find a non-iterative algorithm such as Luke provided for this problem.
function memoize (f)
    local cache = {}
    return function (x)
        if cache[x] then
            return cache[x]
        else
            local y = f(x)
            cache[x] = y
            return y
        end
    end
end
triangle = memoize(triangle);
Note that to avoid a stack overflow, triangle would still need to be seeded.
Here's something that works without converting the arguments to strings.
The only caveat is that it can't handle a nil argument. But the accepted solution can't distinguish the value nil from the string "nil", so that's probably OK.
local function m(f)
    local t = { }
    local function mf(x, ...) -- memoized f
        assert(x ~= nil, 'nil passed to memoized function')
        if select('#', ...) > 0 then
            t[x] = t[x] or m(function(...) return f(x, ...) end)
            return t[x](...)
        else
            t[x] = t[x] or f(x)
            assert(t[x] ~= nil, 'memoized function returns nil')
            return t[x]
        end
    end
    return mf
end
I've been inspired by this question to implement (yet another) flexible memoize function in Lua.
https://github.com/kikito/memoize.lua
Main advantages:
Accepts a variable number of arguments
Doesn't use tostring; instead, it organizes the cache in a tree structure, using the parameters to traverse it.
Works just fine with functions that return multiple values.
Pasting the code here as reference:
local globalCache = {}

local function getFromCache(cache, args)
    local node = cache
    for i = 1, #args do
        if not node.children then return {} end
        node = node.children[args[i]]
        if not node then return {} end
    end
    return node.results
end

local function insertInCache(cache, args, results)
    local arg
    local node = cache
    for i = 1, #args do
        arg = args[i]
        node.children = node.children or {}
        node.children[arg] = node.children[arg] or {}
        node = node.children[arg]
    end
    node.results = results
end

-- public function
local function memoize(f)
    globalCache[f] = { results = {} }
    return function (...)
        local results = getFromCache( globalCache[f], {...} )
        if #results == 0 then
            results = { f(...) }
            insertInCache(globalCache[f], {...}, results)
        end
        return unpack(results)
    end
end

return memoize
Here is a generic C# 3.0 implementation, in case it helps:
public static class Memoization
{
    public static Func<T, TResult> Memoize<T, TResult>(this Func<T, TResult> function)
    {
        var cache = new Dictionary<T, TResult>();
        var nullCache = default(TResult);
        var isNullCacheSet = false;
        return parameter =>
        {
            TResult value;
            if (parameter == null && isNullCacheSet)
            {
                return nullCache;
            }
            if (parameter == null)
            {
                nullCache = function(parameter);
                isNullCacheSet = true;
                return nullCache;
            }
            if (cache.TryGetValue(parameter, out value))
            {
                return value;
            }
            value = function(parameter);
            cache.Add(parameter, value);
            return value;
        };
    }
}
(Quoted from a French blog article.)
In the vein of posting memoization in different languages, I'd like to respond to #onebyone.livejournal.com with a non-language-changing C++ example.
First, a memoizer for single arg functions:
template <class Result, class Arg, class ResultStore = std::map<Arg, Result> >
class memoizer1 {
public:
    template <class F>
    const Result& operator()(F f, const Arg& a) {
        typename ResultStore::const_iterator it = memo_.find(a);
        if(it == memo_.end()) {
            it = memo_.insert(make_pair(a, f(a))).first;
        }
        return it->second;
    }
private:
    ResultStore memo_;
};
Create an instance of the memoizer and feed it your function and argument. Just make sure not to share the same memo between two different functions (but you can share it between different implementations of the same function).
Next, a driver function and an implementation. Only the driver function needs to be public:
int fib(int); // driver
int fib_(int); // implementation
Implemented:
int fib_(int n) {
    ++total_ops;
    if(n == 0 || n == 1)
        return 1;
    else
        return fib(n-1) + fib(n-2);
}
And the driver, to memoize:
int fib(int n) {
    static memoizer1<int,int> memo;
    return memo(fib_, n);
}
Permalink showing output on codepad.org. Number of calls is measured to verify correctness. (insert unit test here...)
This only memoizes one-input functions. Generalizing for multiple args or varying arguments is left as an exercise for the reader.
In Perl generic memoization is easy to get. The Memoize module is part of the Perl core and is highly reliable, flexible, and easy to use.
The example from its manpage:
# This is the documentation for Memoize 1.01
use Memoize;
memoize('slow_function');
slow_function(arguments); # Is faster than it was before
You can add, remove, and customize memoization of functions at run time! You can provide callbacks for custom memento computation.
Memoize.pm even has facilities for making the memento cache persistent, so it does not need to be re-filled on each invocation of your program!
Here's the documentation: http://perldoc.perl.org/5.8.8/Memoize.html
Extending the idea, it's also possible to memoize functions with two input parameters:
function memoize2 (f)
    local cache = {}
    return function (x, y)
        if cache[x..','..y] then
            return cache[x..','..y]
        else
            local z = f(x,y)
            cache[x..','..y] = z
            return z
        end
    end
end
Notice that parameter order matters in the caching algorithm, so if parameter order doesn't matter in the functions to be memoized the odds of getting a cache hit would be increased by sorting the parameters before checking the cache.
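For example, here is a sketch of that idea in Java (continuing this thread's habit of hopping languages): a memoizer for a commutative two-int function that sorts each pair into a canonical key before the lookup:

import java.util.HashMap;
import java.util.Map;
import java.util.function.IntBinaryOperator;

class CommutativeMemo {
    // Assumes f(x, y) == f(y, x); (3, 5) and (5, 3) then share one cache entry.
    static IntBinaryOperator memoize(IntBinaryOperator f) {
        Map<Long, Integer> cache = new HashMap<>();
        return (x, y) -> {
            int lo = Math.min(x, y), hi = Math.max(x, y);
            long key = ((long) lo << 32) | (hi & 0xFFFFFFFFL); // pack the sorted pair
            return cache.computeIfAbsent(key, k -> f.applyAsInt(lo, hi));
        };
    }
}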
But it's important to note that some functions can't be profitably memoized. I wrote memoize2 to see if the recursive Euclidean algorithm for finding the greatest common divisor could be sped up.
function gcd (a, b)
    if b == 0 then return a end
    return gcd(b, a%b)
end
As it turns out, gcd doesn't respond well to memoization. The calculation it does is far less expensive than the caching algorithm. Even for large numbers, it terminates fairly quickly. After a while, the cache grows very large. This algorithm is probably as fast as it can be.
Recursion isn't necessary. The nth triangle number is n(n+1)/2, so...
public int triangle(final int n) {
    return n * (n + 1) / 2;
}
Please don't recurse this. Either use the x*(x+1)/2 formula or simply iterate the values and memoize as you go.
int[] memo = new int[n+1];
int sum = 0;
for(int i = 0; i <= n; ++i)
{
    sum += i;
    memo[i] = sum;
}
return memo[n];

Expression Evaluation and Tree Walking using polymorphism? (ala Steve Yegge)

This morning, I was reading Steve Yegge's When Polymorphism Fails, when I came across a question that a co-worker of his used to ask potential employees when they came for their interview at Amazon.
As an example of polymorphism in action, let's look at the classic "eval" interview question, which (as far as I know) was brought to Amazon by Ron Braunstein. The question is quite a rich one, as it manages to probe a wide variety of important skills: OOP design, recursion, binary trees, polymorphism and runtime typing, general coding skills, and (if you want to make it extra hard) parsing theory.

At some point, the candidate hopefully realizes that you can represent an arithmetic expression as a binary tree, assuming you're only using binary operators such as "+", "-", "*", "/". The leaf nodes are all numbers, and the internal nodes are all operators. Evaluating the expression means walking the tree. If the candidate doesn't realize this, you can gently lead them to it, or if necessary, just tell them.

Even if you tell them, it's still an interesting problem.

The first half of the question, which some people (whose names I will protect to my dying breath, but their initials are Willie Lewis) feel is a Job Requirement If You Want To Call Yourself A Developer And Work At Amazon, is actually kinda hard. The question is: how do you go from an arithmetic expression (e.g. in a string) such as "2 + (2)" to an expression tree. We may have an ADJ challenge on this question at some point.

The second half is: let's say this is a 2-person project, and your partner, who we'll call "Willie", is responsible for transforming the string expression into a tree. You get the easy part: you need to decide what classes Willie is to construct the tree with. You can do it in any language, but make sure you pick one, or Willie will hand you assembly language. If he's feeling ornery, it will be for a processor that is no longer manufactured in production.

You'd be amazed at how many candidates boff this one.

I won't give away the answer, but a Standard Bad Solution involves the use of a switch or case statement (or just good old-fashioned cascaded-ifs). A Slightly Better Solution involves using a table of function pointers, and the Probably Best Solution involves using polymorphism. I encourage you to work through it sometime. Fun stuff!
So, let's try to tackle the problem all three ways. How do you go from an arithmetic expression (e.g. in a string) such as "2 + (2)" to an expression tree using cascaded-if's, a table of function pointers, and/or polymorphism?
Feel free to tackle one, two, or all three.
[update: title modified to better match what most of the answers have been.]
Polymorphic Tree Walking, Python version
#!/usr/bin/python

class Node:
    """base class, you should not process one of these"""
    def process(self):
        raise Exception('you should not be processing a node')

class BinaryNode(Node):
    """base class for binary nodes"""
    def __init__(self, _left, _right):
        self.left = _left
        self.right = _right
    def process(self):
        raise Exception('you should not be processing a binarynode')

class Plus(BinaryNode):
    def process(self):
        return self.left.process() + self.right.process()

class Minus(BinaryNode):
    def process(self):
        return self.left.process() - self.right.process()

class Mul(BinaryNode):
    def process(self):
        return self.left.process() * self.right.process()

class Div(BinaryNode):
    def process(self):
        return self.left.process() / self.right.process()

class Num(Node):
    def __init__(self, _value):
        self.value = _value
    def process(self):
        return self.value

def demo(n):
    print n.process()

demo(Num(2))                                          # 2
demo(Plus(Num(2), Num(5)))                            # 2 + 5
demo(Plus(Mul(Num(2), Num(3)), Div(Num(10), Num(5)))) # (2 * 3) + (10 / 5)
The tests are just building up the binary trees by using constructors.
Program structure:
abstract base class: Node
    all Nodes inherit from this class
abstract base class: BinaryNode
    all binary operators inherit from this class
    process method does the work of evaluating the expression and returning the result
binary operator classes: Plus, Minus, Mul, Div
    two child nodes, one each for left side and right side subexpressions
number class: Num
    holds a leaf-node numeric value, e.g. 17 or 42
The problem, I think, is that we need to parse parentheses, and yet they are not a binary operator. Should we take (2) as a single token, that evaluates to 2?
The parens don't need to show up in the expression tree, but they do affect its shape. E.g., the tree for (1+2)+3 is different from 1+(2+3):
    +
   / \
  +   3
 / \
1   2

versus

  +
 / \
1   +
   / \
  2   3
The parentheses are a "hint" to the parser (e.g., per superjoe30, to "recursively descend")
This gets into parsing/compiler theory, which is kind of a rabbit hole... The Dragon Book is the standard text for compiler construction, and takes this to extremes. In this particular case, you want to construct a context-free grammar for basic arithmetic, then use that grammar to parse out an abstract syntax tree. You can then iterate over the tree, reducing it from the bottom up (it's at this point you'd apply the polymorphism/function pointers/switch statement to reduce the tree).
I've found these notes to be incredibly helpful in compiler and parsing theory.
Representing the Nodes
If we want to include parentheses, we need 5 kinds of nodes:

the binary nodes: Add, Minus, Mul, Div; these have two children, a left and a right side

     +
    / \
node   node

a node to hold a value: Val; no child nodes, just a numeric value

a node to keep track of the parens: Paren; a single child node for the subexpression

( )
 |
node
For a polymorphic solution, we need to have this kind of class relationship:

Node
    BinaryNode : inherit from Node
        Plus : inherit from BinaryNode
        Minus : inherit from BinaryNode
        Mul : inherit from BinaryNode
        Div : inherit from BinaryNode
    Value : inherit from Node
    Paren : inherit from Node
There is a virtual function for all nodes called eval(). If you call that function, it will return the value of that subexpression.
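A minimal sketch of that class relationship in Java, with eval() on every node (the names follow the list above):

abstract class Node { abstract double eval(); }

abstract class BinaryNode extends Node {
    final Node left, right;
    BinaryNode(Node left, Node right) { this.left = left; this.right = right; }
}
class Plus  extends BinaryNode { Plus(Node l, Node r)  { super(l, r); } double eval() { return left.eval() + right.eval(); } }
class Minus extends BinaryNode { Minus(Node l, Node r) { super(l, r); } double eval() { return left.eval() - right.eval(); } }
class Mul   extends BinaryNode { Mul(Node l, Node r)   { super(l, r); } double eval() { return left.eval() * right.eval(); } }
class Div   extends BinaryNode { Div(Node l, Node r)   { super(l, r); } double eval() { return left.eval() / right.eval(); } }

class Value extends Node {
    final double v;
    Value(double v) { this.v = v; }
    double eval() { return v; }
}

class Paren extends Node { // records the parens; evaluates to its single child
    final Node child;
    Paren(Node child) { this.child = child; }
    double eval() { return child.eval(); }
}

// "2 + (2)" becomes: new Plus(new Value(2), new Paren(new Value(2))).eval() == 4.0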
A string tokenizer plus an LL(1) parser will give you an expression tree... the polymorphism way might involve an abstract Arithmetic class with an "evaluate(a,b)" function, which is overridden for each of the operators involved (Addition, Subtraction, etc.) to return the appropriate value, and the tree contains Integers and Arithmetic operators, which can be evaluated by a post-order traversal of the tree.
I won't give away the answer, but a Standard Bad Solution involves the use of a switch or case statement (or just good old-fashioned cascaded-ifs). A Slightly Better Solution involves using a table of function pointers, and the Probably Best Solution involves using polymorphism.
The last twenty years of evolution in interpreters can be seen as going the other way - polymorphism (e.g. naive Smalltalk metacircular interpreters) to function pointers (naive lisp implementations, threaded code, C++) to switch (naive byte code interpreters), and then onwards to JITs and so on - which either require very big classes, or (in singly polymorphic languages) double-dispatch, which reduces the polymorphism to a type-case, and you're back at stage one. What definition of 'best' is in use here?
For simple stuff a polymorphic solution is OK - here's one I made earlier, but either stack and bytecode/switch or exploiting the runtime's compiler is usually better if you're, say, plotting a function with a few thousand data points.
Hm... I don't think you can write a top-down parser for this without backtracking, so it has to be some sort of a shift-reduce parser. LR(1) or even LALR will of course work just fine with the following (ad-hoc) language definition:
Start -> E1
E1 -> E1+E1 | E1-E1
E1 -> E2*E2 | E2/E2 | E2
E2 -> number | (E1)
Separating it out into E1 and E2 is necessary to maintain the precedence of * and / over + and -.
But this is how I would do it if I had to write the parser by hand:
Two stacks, one storing nodes of the tree as operands and one storing operators.
Read the input left to right, make leaf nodes of the numbers and push them onto the operand stack.
If you have >= 2 operands on the stack, pop 2, combine them with the topmost operator in the operator stack, and push this structure back onto the operand stack, unless
the next operator has higher precedence than the one currently on top of the stack.
This leaves us with the problem of handling brackets. One elegant (I thought) solution is to store the precedence of each operator as a number in a variable. So initially,
int plus = 1, minus = 1;
int mul = 2, div = 2;
Now every time you see a left bracket, increment all these variables by 2, and every time you see a right bracket, decrement all the variables by 2.
This will ensure that the + in 3*(4+5) has higher precedence than the *, and 3*4 will not be pushed onto the stack. Instead it will wait for 5, push 4+5, then push 3*(4+5).
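Here is a sketch of that hand-written two-stack parser in Java (Java 17 syntax for brevity), reusing the Node classes from the earlier polymorphic sketch; each operator's effective precedence is its base precedence plus 2 for every enclosing bracket, exactly as described:

import java.util.ArrayDeque;
import java.util.Deque;

final class TwoStackParser {
    record Op(char symbol, int prec) {}

    static Node parse(String input) {
        Deque<Node> operands = new ArrayDeque<>();  // subtrees built so far
        Deque<Op> operators = new ArrayDeque<>();   // pending operators
        int depth = 0;                              // +2 per open bracket
        for (int i = 0; i < input.length(); ) {
            char c = input.charAt(i);
            if (Character.isDigit(c)) {
                int j = i;
                while (j < input.length() && Character.isDigit(input.charAt(j))) j++;
                operands.push(new Value(Double.parseDouble(input.substring(i, j))));
                i = j;
            } else if (c == '(') { depth += 2; i++; }
            else if (c == ')') { depth -= 2; i++; }
            else if (c == '+' || c == '-' || c == '*' || c == '/') {
                int prec = ((c == '+' || c == '-') ? 1 : 2) + depth;
                // reduce while the stacked operator binds at least as tightly
                while (!operators.isEmpty() && operators.peek().prec() >= prec) {
                    reduce(operands, operators);
                }
                operators.push(new Op(c, prec));
                i++;
            } else i++; // skip whitespace
        }
        while (!operators.isEmpty()) reduce(operands, operators);
        return operands.pop();
    }

    private static void reduce(Deque<Node> operands, Deque<Op> operators) {
        Op op = operators.pop();
        Node right = operands.pop(), left = operands.pop();
        operands.push(switch (op.symbol()) {
            case '+' -> new Plus(left, right);
            case '-' -> new Minus(left, right);
            case '*' -> new Mul(left, right);
            default  -> new Div(left, right);
        });
    }
}

// e.g. TwoStackParser.parse("3*(4+5)").eval() == 27.0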
Re: Justin
I think the tree would look something like this:
  +
 / \
2  ( )
    |
    2
Basically, you'd have an "eval" node, that just evaluates the tree below it. That would then be optimized out to just being:
  +
 / \
2   2
In this case the parens aren't required and don't add anything. They don't add anything logically, so they'd just go away.
I think the question is about how to write a parser, not the evaluator. Or rather, how to create the expression tree from a string.
Case statements that return a base class don't exactly count.
The basic structure of a "polymorphic" solution (which is another way of saying: I don't care what you build this with, I just want to extend it by rewriting the least amount of code possible) is deserializing an object hierarchy from a stream with a (dynamic) set of known types.
The crux of the implementation of the polymorphic solution is to have a way to create an expression object from a pattern matcher, likely recursive. I.e., map a BNF or similar syntax to an object factory.
Or maybe this is the real question: how can you represent (2) as a BST? That is the part that is tripping me up.
Recursion.
#Justin:
Look at my note on representing the nodes. If you use that scheme, then
2 + (2)
can be represented as

  .
 / \
2  ( )
    |
    2
Should use a functional language, IMO. Trees are harder to represent and manipulate in OO languages.
As people have mentioned previously, when you use expression trees, parens are not necessary. The order of operations becomes trivial and obvious when you're looking at an expression tree. The parens are hints to the parser.
While the accepted answer is the solution to one half of the problem, the other half - actually parsing the expression - is still unsolved. Typically, these sorts of problems can be solved using a recursive descent parser. Writing such a parser is often a fun exercise, but most modern tools for language parsing will abstract that away for you.
The parser is also significantly harder if you allow floating point numbers in your string. I had to create a DFA to accept floating point numbers in C -- it was a very painstaking and detailed task. Remember, valid floating points include: 10, 10., 10.123, 9.876e-5, 1.0f, .025, etc. I assume some dispensation from this (in favor of simplicity and brevity) was made in the interview.
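For what it's worth, a regular expression gets you most of the way there. A sketch in Java that accepts all of the examples above (though not C's full floating-point grammar, e.g. hex floats):

import java.util.regex.Pattern;

// Matches 10, 10., 10.123, 9.876e-5, 1.0f and .025
static final Pattern FLOAT =
    Pattern.compile("(\\d+\\.?\\d*|\\.\\d+)([eE][+-]?\\d+)?[fF]?");

// e.g. FLOAT.matcher("9.876e-5").matches() == true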
I've written such a parser with some basic techniques like
Infix -> RPN and
Shunting Yard and
Tree Traversals.
Here is the implementation I've come up with.
It's written in C++ and compiles on both Linux and Windows.
Any suggestions/questions are welcome.
So, let's try to tackle the problem all three ways. How do you go from an arithmetic expression (e.g. in a string) such as "2 + (2)" to an expression tree using cascaded-if's, a table of function pointers, and/or polymorphism?
This is interesting, but I don't think this belongs to the realm of object-oriented programming... I think it has more to do with parsing techniques.
I've kind of chucked this C# console app together as a bit of a proof of concept. Have a feeling it could be a lot better (that switch statement in GetNode is kind of clunky; it's there because I hit a blank trying to map a class name to an operator). Any suggestions on how it could be improved are very welcome.
using System;

class Program
{
    static void Main(string[] args)
    {
        string expression = "(((3.5 * 4.5) / (1 + 2)) + 5)";
        Console.WriteLine(string.Format("{0} = {1}", expression, new Expression.ExpressionTree(expression).Value));
        Console.WriteLine("\nShow's over folks, press a key to exit");
        Console.ReadKey(false);
    }
}

namespace Expression
{
    // -------------------------------------------------------
    abstract class NodeBase
    {
        public abstract double Value { get; }
    }

    // -------------------------------------------------------
    class ValueNode : NodeBase
    {
        public ValueNode(double value)
        {
            _double = value;
        }

        private double _double;
        public override double Value
        {
            get
            {
                return _double;
            }
        }
    }

    // -------------------------------------------------------
    abstract class ExpressionNodeBase : NodeBase
    {
        protected NodeBase GetNode(string expression)
        {
            // Remove parenthesis
            expression = RemoveParenthesis(expression);

            // Is expression just a number?
            double value = 0;
            if (double.TryParse(expression, out value))
            {
                return new ValueNode(value);
            }
            else
            {
                int pos = ParseExpression(expression);
                if (pos > 0)
                {
                    string leftExpression = expression.Substring(0, pos - 1).Trim();
                    string rightExpression = expression.Substring(pos).Trim();

                    switch (expression.Substring(pos - 1, 1))
                    {
                        case "+":
                            return new Add(leftExpression, rightExpression);
                        case "-":
                            return new Subtract(leftExpression, rightExpression);
                        case "*":
                            return new Multiply(leftExpression, rightExpression);
                        case "/":
                            return new Divide(leftExpression, rightExpression);
                        default:
                            throw new Exception("Unknown operator");
                    }
                }
                else
                {
                    throw new Exception("Unable to parse expression");
                }
            }
        }

        private string RemoveParenthesis(string expression)
        {
            if (expression.Contains("("))
            {
                expression = expression.Trim();

                int level = 0;
                int pos = 0;

                foreach (char token in expression.ToCharArray())
                {
                    pos++;
                    switch (token)
                    {
                        case '(':
                            level++;
                            break;
                        case ')':
                            level--;
                            break;
                    }

                    if (level == 0)
                    {
                        break;
                    }
                }

                if (level == 0 && pos == expression.Length)
                {
                    expression = expression.Substring(1, expression.Length - 2);
                    expression = RemoveParenthesis(expression);
                }
            }
            return expression;
        }

        private int ParseExpression(string expression)
        {
            int winningLevel = 0;
            byte winningTokenWeight = 0;
            int winningPos = 0;

            int level = 0;
            int pos = 0;

            foreach (char token in expression.ToCharArray())
            {
                pos++;

                switch (token)
                {
                    case '(':
                        level++;
                        break;
                    case ')':
                        level--;
                        break;
                }

                if (level <= winningLevel)
                {
                    if (OperatorWeight(token) > winningTokenWeight)
                    {
                        winningLevel = level;
                        winningTokenWeight = OperatorWeight(token);
                        winningPos = pos;
                    }
                }
            }
            return winningPos;
        }

        private byte OperatorWeight(char value)
        {
            switch (value)
            {
                case '+':
                case '-':
                    return 3;
                case '*':
                    return 2;
                case '/':
                    return 1;
                default:
                    return 0;
            }
        }
    }

    // -------------------------------------------------------
    class ExpressionTree : ExpressionNodeBase
    {
        protected NodeBase _rootNode;

        public ExpressionTree(string expression)
        {
            _rootNode = GetNode(expression);
        }

        public override double Value
        {
            get
            {
                return _rootNode.Value;
            }
        }
    }

    // -------------------------------------------------------
    abstract class OperatorNodeBase : ExpressionNodeBase
    {
        protected NodeBase _leftNode;
        protected NodeBase _rightNode;

        public OperatorNodeBase(string leftExpression, string rightExpression)
            : base()
        {
            _leftNode = GetNode(leftExpression);
            _rightNode = GetNode(rightExpression);
        }
    }

    // -------------------------------------------------------
    class Add : OperatorNodeBase
    {
        public Add(string leftExpression, string rightExpression)
            : base(leftExpression, rightExpression)
        {
        }

        public override double Value
        {
            get
            {
                return _leftNode.Value + _rightNode.Value;
            }
        }
    }

    // -------------------------------------------------------
    class Subtract : OperatorNodeBase
    {
        public Subtract(string leftExpression, string rightExpression)
            : base(leftExpression, rightExpression)
        {
        }

        public override double Value
        {
            get
            {
                return _leftNode.Value - _rightNode.Value;
            }
        }
    }

    // -------------------------------------------------------
    class Divide : OperatorNodeBase
    {
        public Divide(string leftExpression, string rightExpression)
            : base(leftExpression, rightExpression)
        {
        }

        public override double Value
        {
            get
            {
                return _leftNode.Value / _rightNode.Value;
            }
        }
    }

    // -------------------------------------------------------
    class Multiply : OperatorNodeBase
    {
        public Multiply(string leftExpression, string rightExpression)
            : base(leftExpression, rightExpression)
        {
        }

        public override double Value
        {
            get
            {
                return _leftNode.Value * _rightNode.Value;
            }
        }
    }
}
Ok, here is my naive implementation. Sorry, I did not feel like using objects for this one, but it is easy to convert. I feel a bit like evil Willie (from Steve's story).
#!/usr/bin/env python

# tree structure: [left argument, operator, right argument, priority level]
tree_root = [None, None, None, None]
# count of parenthesis nesting
parenthesis_level = 0
# current node with empty right argument
current_node = tree_root

# indices in tree_root nodes: Left, Operator, Right, PRiority
L, O, R, PR = 0, 1, 2, 3

# functions that realise operators
def sum(a, b):
    return a + b

def diff(a, b):
    return a - b

def mul(a, b):
    return a * b

def div(a, b):
    return a / b

# tree evaluator
def process_node(n):
    try:
        len(n)
    except TypeError:
        return n
    left = process_node(n[L])
    right = process_node(n[R])
    return n[O](left, right)

# mapping operators to relevant functions
o2f = {'+': sum, '-': diff, '*': mul, '/': div, '(': None, ')': None}

# converts a token to a node in the tree
def convert_token(t):
    global current_node, tree_root, parenthesis_level
    if t == '(':
        parenthesis_level += 2
        return
    if t == ')':
        parenthesis_level -= 2
        return
    try: # assumption that we have just an integer
        l = int(t)
    except (ValueError, TypeError):
        pass # if not, no problem
    else:
        if tree_root[L] is None: # if it is the first number, put it on the left of the root node
            tree_root[L] = l
        else: # put it on the right of current_node
            current_node[R] = l
        return
    priority = (1 if t in '+-' else 2) + parenthesis_level
    # if tree_root does not have an operator, put it there
    if tree_root[O] is None and t in o2f:
        tree_root[O] = o2f[t]
        tree_root[PR] = priority
        return
    # if the new node has less or equal priority, put it on top of the tree
    if tree_root[PR] >= priority:
        temp = [tree_root, o2f[t], None, priority]
        tree_root = current_node = temp
        return
    # starting from the root, search for a place with higher priority in the hierarchy
    current_node = tree_root
    while type(current_node[R]) != type(1) and priority > current_node[R][PR]:
        current_node = current_node[R]
    # insert the new node
    temp = [current_node[R], o2f[t], None, priority]
    current_node[R] = temp
    current_node = temp

def parse(e):
    token = ''
    for c in e:
        if c <= '9' and c >= '0':
            token += c
            continue
        if c == ' ':
            if token != '':
                convert_token(token)
                token = ''
            continue
        if c in o2f:
            if token != '':
                convert_token(token)
            convert_token(c)
            token = ''
            continue
        print "Unrecognized character:", c
    if token != '':
        convert_token(token)

def main():
    parse('(((3 * 4) / (1 + 2)) + 5)')
    print tree_root
    print process_node(tree_root)

if __name__ == '__main__':
    main()