A fast coding tip somehow ended up making the code slower in Julia - optimization

I heard that being conscious of type stability contributes a lot to high performance in Julia programming, so I tried to measure how much time I could save by rewriting a type-unstable function into a type-stable version.
As many people say, I assumed that type-stable code would of course perform better than type-unstable code. However, the result was the opposite:
# type-unstable vs type-stable
# type-unstable
function positive(x)
if x < 0
return 0.0
else
return x
end
end
# type-stable
function positive_safe(x)
if x < 0
return zero(x)
else
return x
end
end
@time for n in 1:100_000_000
a = 2^( positive(-n) + 1 )
end
@time for n in 1:100_000_000
b = 2^( positive_safe(-n) + 1 )
end
result:
0.040080 seconds
0.150596 seconds
I cannot believe this. Are there mistakes in my code, or is this really the case?
Any information would be appreciated.
Context
Operating System and version: Windows 10
Browser and version: Google Chrome 90.0.4430.212(Official Build) (64 bit)
JupyterLab version: 3.0.14
@btime result
Just replacing @time with @btime in my code above:
@btime for n in 1:100_000_000
a = 2^( positive(-n) + 1 )
end
# -> 1.500 ns
@btime for n in 1:100_000_000
b = 2^( positive_safe(-n) + 1 )
end
# -> 503.146 ms
Still weird.
The exact same code DNF showed me:
using BenchmarkTools
@btime 2^(positive(-n) + 1) setup=(n=rand(1:10^8))
# -> 32.435 ns (0 allocations: 0 bytes)
@btime 2^(positive_safe(-n) + 1) setup=(n=rand(1:10^8))
#-> 3.103 ns (0 allocations: 0 bytes)
Works as expected.
I still don't understand what is happening.
I feel like I need to understand the usage of @btime and the benchmarking process better.
By the way, as I said above, I'm running these benchmarks in JupyterLab.

The problem with your benchmark is that you are testing code with different logic:
2 ^ (integer value)
and
2 ^ (float value)
But the most crucial part: if a and b are not defined before the loop, the Julia compiler may remove the block entirely. Your performance depends very much on whether a and b were defined beforehand, and whether they were defined in the global scope or not.
And the power operation is the time-consuming central part of your code (not the type-unstable part):
the positive function returns a Float in your case, while positive_safe returns an Int.
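As a side note of my own (not part of this answer): the return-type difference is easy to see with @code_warntype, which flags a Union return type for the unstable version. A minimal sketch, output abbreviated:

julia> @code_warntype positive(-3)
Body::Union{Float64, Int64}    # two possible return types -> type-unstable

julia> @code_warntype positive_safe(-3)
Body::Int64                    # one concrete return type -> type-stable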
Code similar to your case (in terms of logic) could look like this:
# type-unstable
function positive(x)
if x < 0
return 0.0
else
return x
end
end
# type-stable
function positive_safe(x)
if x < 0
return 0.0
else
return Float64(x)
end
end
function test1()
a = 0.0
for n in 1:100_000_000
a += 2^( positive(-n) + 1 )
end
a
end
function test2()
b = 0.0
for n in 1:100_000_000
b += 2^( positive_safe(-n) + 1 )
end
b
end
@btime test1()
@btime test2()
98.045 ms (0 allocations: 0 bytes)
2.0e8
97.948 ms (0 allocations: 0 bytes)
2.0e8
The results are almost the same, since your type instability is not the bottleneck in this case.
If we instead test a function similar to your case where a/b was not defined:
function test3()
b = 0.0
for n in 1:100_000_000
b += 2^( positive_safe(-n) + 1 )
end
nothing
end
@btime test3()
The benchmark will show:
1.611 ns
This is not because my laptop did 100_000_000 iterations in 1.611 ns, but because the Julia compiler is smart enough to understand that the whole test3 function may be replaced with nothing.

This is a benchmarking problem. The @time macro is not suitable for microbenchmarks. Use the BenchmarkTools.jl package, and read the user manual; it is easy to make mistakes when benchmarking.
Here's how to do it:
jl> using BenchmarkTools
jl> @btime 2^(positive(-n) + 1) setup=(n=rand(1:10^8))
6.507 ns (0 allocations: 0 bytes)
2.0
jl> @btime 2^(positive_safe(-n) + 1) setup=(n=rand(1:10^8))
2.100 ns (0 allocations: 0 bytes)
2
As you can see, the type-stable function is faster.
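Another common pitfall, not raised above but worth noting as an aside: if you benchmark an expression that reads non-constant globals without a setup, interpolate them with $ so BenchmarkTools treats them as local values, for example:

jl> n = 17
jl> @btime 2^(positive_safe(-$n) + 1)

Otherwise the lookup of the untyped global itself can dominate the measurement.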

The problem, as Vitaliy said, is that floating-point powers, which can be computed with logarithms, can be faster than integer powers, which are computed as a loop of multiplications:
using BenchmarkTools
# type-unstable vs type-unstable
# type-unstable
function positive_float_unstable(x)
if x < 0
return 0.0
else
return x
end
end
# type-unstable
function positive_int_unstable(x)
if x < 0
return 0
else
return x
end
end
# type-stable
function positive_float_stable(x)
if x < 0
return 0.0
else
return Float64(x)
end
end
# type-stable
function positive_int_stable(x)
if x < 0
return 0
else
return Int(x)
end
end
println("unstable float")
@btime for n in 1:100_000_000
a = 2^( positive_float_unstable(-n) + 1 )
end
println("unstable int")
@btime for n in 1:100_000_000
b = 2^( positive_int_unstable(-n) + 1 )
end
println("stable float")
@btime for n in 1:100_000_000
a = 2^( positive_float_stable(-n) + 1 )
end
println("stable int")
@btime for n in 1:100_000_000
b = 2^( positive_int_stable(-n) + 1 )
end
Results:
unstable float
1.300 ns (0 allocations: 0 bytes)
unstable int
179.232 ms (0 allocations: 0 bytes)
stable float
1.300 ns (0 allocations: 0 bytes)
stable int
178.990 ms (0 allocations: 0 bytes)
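To isolate the cost of the power operation itself, here is a small sketch of my own (not from the answer above; exact timings will vary by machine) that benchmarks ^ with a runtime exponent so nothing can be constant-folded away:

using BenchmarkTools
@btime 2^e setup=(e=rand(1:30))            # Int exponent: exact power by squaring
@btime 2^e setup=(e=Float64(rand(1:30)))   # Float64 exponent: floating-point pow

If the floating-point version comes out faster on your machine, that supports the explanation above; either way, the original comparison was mixing these two different operations.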


Is there an algorithm to find values of a polynomial with big integers quickly, without loops?

For example, if I want to find
1085912312763120759250776993188102125849391224162 = a^9+b^9+c^9+d
the code needs to produce
a=3456
b=78525
c=217423
d=215478
I do not need these specific values, only that a, b and c have at most 6 digits and that d is as small as possible.
Is there a quick way to find them?
I appreciate any help you can give me.
I have tried with nested loops, but it is extremely slow and the code gets stuck.
Any help in VB or other code would be appreciated. I think the structure is more important than the language in this case.
Imports System.Numerics
Public Class Form1
Private Sub Button1_Click(sender As Object, e As EventArgs) Handles Button1.Click
Dim Value As BigInteger = BigInteger.Parse("1085912312763120759250776993188102125849391224162")
Dim powResult As BigInteger
Dim dResult As BigInteger
Dim a As Integer
Dim b As Integer
Dim c As Integer
Dim d As Integer
For i = 1 To 999999
For j = 1 To 999999
For k = 1 To 999999
powResult = BigInteger.Add(BigInteger.Add(BigInteger.Pow(i, 9), BigInteger.Pow(j, 9)), BigInteger.Pow(k, 9))
dResult = BigInteger.Subtract(Value, powResult)
If Len(dResult.ToString) <= 6 Then
a = i
b = j
c = k
d = dResult
RichTextBox1.Text = a & " , " & b & " , " & c & " , " & d
Exit For
Exit For
Exit For
End If
Next
Next
Next
End Sub
End Class
UPDATE
I wrote the code in VB. But with this code, a is correct and b is correct, but c is incorrect, and so the result is incorrect.
a^9 + b^9 + c^9 + d is a number bigger than the initial value.
The code should bring:
a= 217423
b= 78525
c= 3456
d= 215478
Total value is OK = 1085912312763120759250776993188102125849391224162
but the code brings:
a= 217423
b= 78525
c= 65957
d= 70333722607339201875244531009974
Total Value is bigger and not equal=1085935936469985777155428248430866412402362281319
What do I need to change in the code to make c = 3456 and d = 215478?
The code is:
Imports System.Numerics
Public Class Form1
Private Function pow9(x As BigInteger) As BigInteger
Dim y As BigInteger
y = x * x ' x^2
y *= y ' x^4
y *= y ' x^8
y *= x ' x^9
Return y
End Function
Private Sub Button1_Click(sender As Object, e As EventArgs) Handles Button1.Click
Dim a, b, c, d, D2, n As BigInteger
Dim aa, bb, cc, dd, ae As BigInteger
D2 = BigInteger.Parse("1085912312763120759250776993188102125849391224162")
'first solution so a is maximal
d = D2
'a = BigIntegerSqrt(D2)
'RichTextBox1.Text = a.ToString
For a = 1 << ((Convert.ToInt32(Math.Ceiling(BigInteger.Log(d, 2))) + 8) / 9) To a > 0 Step -1
If (pow9(a) <= d) Then
d -= pow9(a)
Exit For
End If
Next
For b = 1 << ((Convert.ToInt32(Math.Ceiling(BigInteger.Log(d, 2))) + 8) / 9) To b > 0 Step -1
If (pow9(b) <= d) Then
d -= pow9(b)
Exit For
End If
Next
For c = 1 << ((Convert.ToInt32(Math.Ceiling(BigInteger.Log(d, 2))) + 8) / 9) To c > 0 Step -1
If (pow9(c) <= d) Then
d -= pow9(c)
Exit For
End If
Next
' minimize d
aa = a
bb = b
cc = c
dd = d
If (aa < 10) Then
ae = 0
Else
ae = aa - 10
End If
For a = aa - 1 To a > ae Step -1 'a goes down few iterations
d = D2 - pow9(a)
For n = 1 << ((Convert.ToInt32(Math.Ceiling(BigInteger.Log(d, 2))) + 8) / 9) To b < n 'b goes up
If (pow9(b) >= d) Then
b = b - 1
d -= pow9(b)
Exit For
End If
Next
For c = 1 << ((Convert.ToInt32(Math.Ceiling(BigInteger.Log(d, 2))) + 8) / 9) To c > 0 Step -1 'c must be search fully
If pow9(c) <= d Then
d -= pow9(c)
Exit For
End If
Next
If d < dd Then 'remember better solution
aa = a
bb = b
cc = c
dd = d
End If
If a < ae Then
Exit For
End If
Next
a = aa
b = bb
c = cc
d = dd
' a,b,c,d is the result
RichTextBox1.Text = D2.ToString
Dim Sum As BigInteger
Dim a9 As BigInteger
Dim b9 As BigInteger
Dim c9 As BigInteger
a9 = BigInteger.Pow(a, 9)
b9 = BigInteger.Pow(b, 9)
c9 = BigInteger.Pow(c, 9)
Sum = BigInteger.Add(BigInteger.Add(BigInteger.Add(a9, b9), c9), d)
RichTextBox2.Text = Sum.ToString
Dim Subst As BigInteger
Subst = BigInteger.Subtract(Sum, D2)
RichTextBox3.Text = Subst.ToString
End Sub
End Class
[Update]
The approach below is an attempt to solve a problem like OP's, yet I erred in reading it.
It is for 1085912312763120759250776993188102125849391224162 = a^9 + b^9 + c^9 + d^9 + e, minimizing e.
I just became too excited about OP's interesting conundrum and read too quickly.
I will review this more later.
OP's approach is O(N*N*N*N) - slow.
Below is an O(N*N*log(N)) one.
Algorithm
Let N = 1,000,000. (Looks like 250,000 is good enough for OP's sum of 1.0859e48.)
Define 160+ bit wide integer math routines.
Define a type pow9 with fields:
int x, y
int160least_t z
Form an array of pow9, a[N*N], populated with x, y, x^9 + y^9 for every x, y in the [1...N] range.
Sort the array on z.
Cost so far: O(N*N*log(N)).
For the array elements indexed [0 ... N*N/2], do a binary search for another array element such that the sum is 1085912312763120759250776993188102125849391224162.
The closest sum is the answer.
Time: O(N*N*log(N))
Space: O(N*N)
It is easy to start with FP math and then later get a better answer with carefully crafted extended integer math.
Try with smaller N and smaller total-sum targets to iron out implementation issues.
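A rough sketch of the pairing idea (my own illustration in Julia, for the a^9 + b^9 + c^9 + d^9 + e reading described above; the function name and toy parameters are mine, and the real problem needs far more memory and a wide-integer type):

function min_remainder(target::BigInt, N::Int)
    # all pair sums x^9 + y^9 for 1 <= x <= y <= N, sorted once: O(N^2 log N)
    sums = sort!([big(x)^9 + big(y)^9 for x in 1:N for y in x:N])
    best = target                               # worst case: e = target itself
    for s in sums
        s > target && break
        i = searchsortedlast(sums, target - s)  # largest pair sum t with s + t <= target
        i == 0 && continue
        best = min(best, target - s - sums[i])
    end
    return best
end

min_remainder(big(10)^20, 200)   # toy target and small N, just to exercise the structure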
In case a, b, c, d might be zero, I got an idea for a fast and simple solution:
First, something better than a brute-force search of a^9 + d = x so that a is maximal (which ensures minimal d)...
Let d = 1085912312763120759250776993188102125849391224162.
Find the max value a such that a^9 <= d.
This is simple, as we know the 9th power multiplies the bit width of its operand by 9, so the max value can be at most a <= 2^(log2(d)/9). Now just search all numbers from this bound down to zero (decrementing) until the 9th power is less than or equal to d. That value will be our a.
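A quick sanity check of that bound (my arithmetic, not in the original answer): for d ≈ 1.086e48 we have log2(d) ≈ 159.6, so a can be at most about 2^(159.6/9) ≈ 2^17.7, roughly 220000. The code below rounds the starting point up to 1 << 18 = 262144, and the first greedy hit indeed lands at a = 217425 rather than anywhere near 999999.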
It is still a brute-force search, but from a much better starting point, so far fewer iterations are required.
We also need to update d, so let
d = d - a^9
Now just find b and c in the same way (using the smaller and smaller remainder d)... these searches are not nested, so they are fast...
b^9 <= d; d -= b^9;
c^9 <= d; d -= c^9;
To improve speed even more, you can hardcode the 9th power using power by squaring...
This will be our initial solution (on my setup it took ~200 ms with 32*8-bit uints), with these results:
x = 1085912312763120759250776993188102125849391224162
1085912312763120759250776993188102125849391224162 (reference)
a = 217425
b = 65957
c = 22886
d = 39113777348346762582909125401671564
Now we want to minimize d, so simply decrement a and search b upwards for as long as a^9 + b^9 still fits under the target. Then search c as before and remember the better solution. The a should be searched downwards until it meets b in the middle, but as both a and b have the same powers, only a few iterations from the first solution might suffice (I used 50; I have no proof of this, it is just my feeling). But even if the full range is used, this has lower complexity than yours, as I have just 2 nested fors instead of your 3, and they all run over smaller ranges...
Here is a small working C++ example (sorry, I have not coded in BASIC for decades):
//---------------------------------------------------------------------------
typedef uint<8> bigint;
//---------------------------------------------------------------------------
bigint pow9(bigint &x)
{
bigint y;
y=x*x; // x^2
y*=y; // x^4
y*=y; // x^8
y*=x; // x^9
return y;
}
//---------------------------------------------------------------------------
void compute()
{
bigint a,b,c,d,D,n;
bigint aa,bb,cc,dd,ae;
D="1085912312763120759250776993188102125849391224162";
// first solution so a is maximal
d=D;
for (a=1<<((d.bits()+8)/9);a>0;a--) if (pow9(a)<=d) break; d-=pow9(a);
for (b=1<<((d.bits()+8)/9);b>0;b--) if (pow9(b)<=d) break; d-=pow9(b);
for (c=1<<((d.bits()+8)/9);c>0;c--) if (pow9(c)<=d) break; d-=pow9(c);
// minimize d
aa=a; bb=b; cc=c; dd=d;
if (aa<50) ae=0; else ae=aa-50;
for (a=aa-1;a>ae;a--) // a goes down few iterations
{
d=D-pow9(a);
for (n=1<<((d.bits()+8)/9),b++;b<n;b++) if (pow9(b)>=d) break; b--; d-=pow9(b); // b goes up
for (c=1<<((d.bits()+8)/9);c>0;c--) if (pow9(c)<=d) break; d-=pow9(c); // c must be search fully
if (d<dd) // remember better solution
{
aa=a; bb=b; cc=c; dd=d;
}
}
a=aa; b=bb; c=cc; d=dd; // a,b,c,d is the result
}
//-------------------------------------------------------------------------
The function bits() just returns the number of occupied bits (similar to log2, but much faster). Here are the final results:
x = 1085912312763120759250776993188102125849391224162
1085912312763120759250776993188102125849391224162 (reference)
a = 217423
b = 78525
c = 3456
d = 215478
It took 1689.651 ms... As you can see, this is much faster than yours; however, I am not sure whether the number of search iterations used while fine-tuning a is OK, or whether it should be scaled by a/b or even cover the full range down to (a+b)/2, which would be much slower than this...
One last thing: I did not bound a, b, c to 999999, so if you want that, just add an if (a>999999) a=999999; statement after each a=1<<((d.bits()+8)/9)...
[Edit1] adding binary search
OK, now all the full searches for the 9th root (except the fine-tuning of a) can be done with a binary search, which improves speed a lot more; ignoring bigint multiplication complexity, this leads to O(n.log(n)) against your O(n^3)... Here is the updated code (with full iteration of a while fitting, so it is safe):
//---------------------------------------------------------------------------
typedef uint<8> bigint;
//---------------------------------------------------------------------------
bigint pow9(bigint &x)
{
bigint y;
y=x*x; // x^2
y*=y; // x^4
y*=y; // x^8
y*=x; // x^9
return y;
}
//---------------------------------------------------------------------------
bigint binsearch_max_pow9(bigint &d) // return biggest x, where x^9 <= d, and lower d by x^9
{ // x = floor(d^(1/9)) , d = remainder
bigint m,x;
for (m=bigint(1)<<((d.bits()+8)/9),x=0;m.isnonzero();m>>=1)
{ x|=m; if (pow9(x)>d) x^=m; }
d-=pow9(x);
return x;
}
//---------------------------------------------------------------------------
void compute()
{
bigint a,b,c,d,D,n;
bigint aa,bb,cc,dd;
D="1085912312763120759250776993188102125849391224162";
// first solution so a is maximal
d=D;
a=binsearch_max_pow9(d);
b=binsearch_max_pow9(d);
c=binsearch_max_pow9(d);
// minimize d
aa=a; bb=b; cc=c; dd=d;
for (a=aa-1;a>=b;a--) // a goes down few iterations
{
d=D-pow9(a);
for (n=1<<((d.bits()+8)/9),b++;b<n;b++) if (pow9(b)>=d) break; b--; d-=pow9(b); // b goes up
c=binsearch_max_pow9(d);
if (d<dd) // remember better solution
{
aa=a; bb=b; cc=c; dd=d;
}
}
a=aa; b=bb; c=cc; d=dd; // a,b,c,d is the result
}
//-------------------------------------------------------------------------
The function m.isnonzero() is the same as m!=0, just faster... The results are the same as for the code above, but the run time is only 821 ms for the full iteration of a, which would take several thousand seconds with the previous code.
I think that, apart from some polynomial discrete-math trick I do not know of, there is only one more thing left to improve, and that is to compute consecutive pow9 values without multiplication, which would boost the speed a lot (as bigint multiplication is by far the slowest operation), like I did here:
How to get a square root for 32 bit input in one clock cycle only?
but I am too lazy to derive it...
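For what it's worth, here is a sketch of the usual "difference engine" trick for consecutive polynomial values (my own illustration in Julia with BigInt, not the answerer's code): since n^9 is a degree-9 polynomial, its 9th forward difference is the constant 9! = 362880, so after seeding a table with ten consecutive powers, each further n^9 costs only 9 big-integer additions and no multiplication.

function consecutive_pow9(n0::Integer, count::Integer)
    d = [big(n0 + i)^9 for i in 0:9]       # seed: (n0)^9 ... (n0+9)^9
    for k in 1:9, i in 10:-1:k+1
        d[i] -= d[i-1]                     # build the forward-difference table
    end
    out = BigInt[]
    for _ in 1:count
        push!(out, d[1])                   # current n^9
        for i in 1:9
            d[i] += d[i+1]                 # advance the table: additions only
        end
    end
    return out
end

consecutive_pow9(1, 5) == [big(n)^9 for n in 1:5]   # quick self-check, should be true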

Reducing memory allocation of a generator in Julia

I am trying to reduce the memory allocation of an inner loop in my code. Below is the part that is not working as expected.
using Random
using StatsBase
using BenchmarkTools
using Distributions
a_dist = Distributions.DiscreteUniform(1, 99)
v_dist = Distributions.DiscreteUniform(1, 2)
population_size = 10000
population = [rand(a_dist, population_size) rand(v_dist, population_size)]
find_all_it3(f::Function, A) = (p[2] for p in eachrow(A) if f(p[1]))
@btime begin
c_pool = find_all_it3(x -> (x < 5), population)
c_pool_dict = countmap(c_pool, alg=:dict)
end
@btime begin
c_pool_indexes = findall(x -> (x < 5) , view(population, :, 1))
c_pool_dict = countmap(population[c_pool_indexes, 2], alg=:dict)
end
I was hoping that the generator (find_all_it3) would not need to allocate much memory.
However, as per the @btime output, it seems that there is an allocation on each iteration.
98.040 μs (10006 allocations: 625.64 KiB)
18.894 μs (18 allocations: 11.95 KiB)
Now, in my scenario the speed and allocations of the findall approach eventually become an issue, hence I was trying to find a better alternative through generators/iterators so that fewer allocations occur; is there a way to do that? Are there options to consider?
I don't have an explanation for it, but here are the results of a few tests I made:
The best time is achieved with view(population, :, 1) .< 5 (test4).
Using broadcast! reduces allocations a bit (test5).
The best way to reduce allocations is to write your own loop (test6).
using BenchmarkTools
using StatsBase
population_size = 10000
population = [rand(1:99, population_size) rand(1:2, population_size)]
find_all_it(f::Function, A) = (p[2] for p in eachrow(A) if f(p[1]))
function test1(population)
c_pool = find_all_it(x -> x < 5, population)
c_pool_dict = countmap(c_pool, alg=:dict)
end
function test3(population)
c_pool_indexes = findall(x -> x < 5, view(population, :, 1))
c_pool_dict = countmap(view(population,c_pool_indexes, 2), alg=:dict)
end
function test4(population)
c_pool_indexes = view(population, :, 1) .< 5
c_pool_dict = countmap(view(population,c_pool_indexes, 2), alg=:dict)
end
function test5(c_pool_indexes, population)
broadcast!(<, c_pool_indexes, view(population, :, 1), 5)
c_pool_dict = countmap(view(population,c_pool_indexes, 2), alg=:dict)
end
function test6(population)
d = Dict{Int,Int}()
for i in eachindex(view(population, :, 1))
if population[i, 1] < 5
d[population[i,2]] = 1 + get(d,population[i,2],0)
end
end
return d
end
julia> @btime test1(population);
68.200 μs (10004 allocations: 625.59 KiB)
julia> @btime test3(population);
14.800 μs (14 allocations: 9.00 KiB)
julia> @btime test4(population);
7.250 μs (8 allocations: 9.33 KiB)
julia> temp = zeros(Bool, population_size);
julia> @btime test5(temp, population);
16.599 μs (5 allocations: 3.78 KiB)
julia> @btime test6(population);
11.299 μs (4 allocations: 608 bytes)
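One more variant that might be worth trying (my own sketch, not benchmarked here): iterating the two columns with zip instead of eachrow avoids building a row view on every iteration, which is the most likely source of the per-element allocations in test1, while keeping the lazy, generator-based style of the question.

find_all_zip(f::Function, A) = (v for (a, v) in zip(view(A, :, 1), view(A, :, 2)) if f(a))

function test7(population)
    countmap(find_all_zip(x -> x < 5, population), alg=:dict)
end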

"Not defined variable" in a 'while loop' in Julia

I am trying to do a sensitivity analysis in Julia using JuMP. Here is my code:
using JuMP, Plots, Gurobi
m=Model(with_optimizer(Gurobi.Optimizer))
@variable(m, x>=0)
@variable(m, y>=0)
@variable(m, k>=0)
k = 0
while k<=1
φ(x,y,k)=3*x+k*y
@objective(m, Max, φ(x,y,k))
@constraint(m, 2*x-4>=0)
@constraint(m, y-0.5*x>=0)
pl=optimize!(m)
k=k+0.2
end
The problem is that I get an error:
UndefVarError: k not defined
What am I missing?
julia> k =0
0
julia> while k<10
k=k+1
end
ERROR: UndefVarError: k not defined
Stacktrace:
[1] top-level scope at ./REPL[11]:2
In Julia, a variable initialised outside a loop cannot, by default, be assigned to inside a top-level loop: the assignment creates a new local variable, so the k on the right-hand side is undefined. To make the assignment target the outer variable, mark it as global inside the loop:
julia> while k<10
global k=k+1
end
Now this works fine
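Applied to the question's loop, the same fix would look roughly like this (just a sketch; it assumes the model, variable and constraint definitions from the question are already in place, with the constraints added once before the loop):

k = 0
while k <= 1
    @objective(m, Max, 3*x + k*y)
    optimize!(m)
    global k = k + 0.2
end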
Disclaimer: This is an alternative solution that was suggested by a member of Julia Discourse.
In a discussion at Julia Discourse, it was suggested to wrap the code in a function in order to increase speed and to avoid the global-scope issue:
function run_code()
model = Model(with_optimizer(Gurobi.Optimizer))
@variable(model, x >= 0)
@variable(model, y >= 0)
@constraint(model, 2x - 4 >= 0)
k = 0
while k <= 1
@objective(model, Max, 3x + k * y)
optimize!(model)
k = k + 0.2
end
end
run_code()

My program in Julia sees syntax error where there is none

I have a problem with this bit of code. Every time I try to run it, it says that I have an unexpected "end". To me everything looks right and I can't figure it out; can someone help me find the solution? The full error and program code are below.
Program:
function mbisekcji(f, a::Float64, b::Float64, delta::Float64, epsilon::Float64)
e = b-a
u = f(a)
v = f(b)
err = 0
iterator = 0
if sign(u) == sign(v)
err = 1
return err
end
while true
e = e/2
c = a+e
w = f(c)
if (norm(e) < delta) || (norm(w) < epsilon)
return w, f(w), iterator, err
end
if sign(w) == sign(u)
b = c
v = w
else
a = c
u = w
end
iterator++
end
end
Error:
LoadError: [91msyntax: unexpected "end"[39m
while loading C:\Users\username\Desktop\Study\zad1.jl, in expression starting on line 60
include_string(::String, ::String) at loading.jl:522
include_string(::Module, ::String, ::String) at Compat.jl:84
(::Atom.##112#116{String,String})() at eval.jl:109
withpath(::Atom.##112#116{String,String}, ::String) at utils.jl:30
withpath(::Function, ::String) at eval.jl:38
hideprompt(::Atom.##111#115{String,String}) at repl.jl:67
macro expansion at eval.jl:106 [inlined]
(::Atom.##110#114{Dict{String,Any}})() at task.jl:80
Also, just to make things easier: line 60 is the second end from the back, the one closing the while loop.
In order to increment a variable by 1 in Julia you have to write
iterator += 1
Julia does not support ++ to increment a variable.
But, for example, you could define a macro to do almost what you want:
julia> macro ++(x)
esc(:($x += 1))
end
@++ (macro with 1 method)
julia> x = 1
1
julia> @++x
2
julia> x
2

Simulate data for repeated binary measures

I can generate a binary variable y as follows:
clear
set more off
gen y =.
replace y = rbinomial(1, .5)
How can I generate n variables y_1, y_2, ..., y_n with a correlation of rho?
This is @pjs's solution in Stata for generating pairs of variables:
clear
set obs 100
set seed 12345
generate x = rbinomial(1, 0.7)
generate y = rbinomial(1, 0.7 + 0.2 * (1 - 0.7)) if x == 1
replace y = rbinomial(1, 0.7 * (1 - 0.2)) if x != 1
summarize x y
Variable | Obs Mean Std. Dev. Min Max
-------------+---------------------------------------------------------
x | 100 .72 .4512609 0 1
y | 100 .67 .4725816 0 1
correlate x y
(obs=100)
| x y
-------------+------------------
x | 1.0000
y | 0.1781 1.0000
And a simulation:
set seed 12345
tempname sim1
tempfile mcresults
postfile `sim1' mu_x mu_y rho using `mcresults', replace
forvalues i = 1 / 100000 {
quietly {
clear
set obs 100
generate x = rbinomial(1, 0.7)
generate y = rbinomial(1, 0.7 + 0.2 * (1 - 0.7)) if x == 1
replace y = rbinomial(1, 0.7 * (1 - 0.2)) if x != 1
summarize x, meanonly
scalar mean_x = r(mean)
summarize y, meanonly
scalar mean_y = r(mean)
corr x y
scalar rho = r(rho)
post `sim1' (mean_x) (mean_y) (rho)
}
}
postclose `sim1'
use `mcresults', clear
summarize *
Variable | Obs Mean Std. Dev. Min Max
-------------+---------------------------------------------------------
mu_x | 100,000 .7000379 .0459078 .47 .89
mu_y | 100,000 .6999094 .0456385 .49 .88
rho | 100,000 .1993097 .1042207 -.2578483 .6294388
Note that in this example I use p = 0.7 and rho = 0.2 instead.
This is @pjs's solution in Stata for generating a time series:
clear
set seed 12345
set obs 1
local p = 0.7
local rho = 0.5
generate y = runiform()
if y <= `p' replace y = 1
else replace y = 0
forvalues i = 1 / 99999 {
set obs `= _N + 1'
local rnd = runiform()
if y[`i'] == 1 {
if `rnd' <= `p' + `rho' * (1 - `p') replace y = 1 in `= `i' + 1'
else replace y = 0 in `= `i' + 1'
}
else {
if `rnd' <= `p' * (1 - `rho') replace y = 1 in `= `i' + 1'
else replace y = 0 in `= `i' + 1'
}
}
Results:
summarize y
Variable | Obs Mean Std. Dev. Min Max
-------------+---------------------------------------------------------
y | 100,000 .70078 .4579186 0 1
generate id = _n
tsset id
corrgram y, lags(5)
-1 0 1 -1 0 1
LAG AC PAC Q Prob>Q [Autocorrelation] [Partial Autocor]
-------------------------------------------------------------------------------
1 0.5036 0.5036 25366 0.0000 |---- |----
2 0.2567 0.0041 31955 0.0000 |-- |
3 0.1273 -0.0047 33576 0.0000 |- |
4 0.0572 -0.0080 33903 0.0000 | |
5 0.0277 0.0032 33980 0.0000 | |
Correlation is a pairwise measure, so I'm assuming that when you talk about binary (Bernoulli) values Y1,...,Yn having a correlation of rho you're viewing them as a time series Yi: i = 1,...,n, of Bernoulli values having a common mean p, variance p*(1-p), and a lag 1 correlation of rho.
I was able to work it out using the definition of correlation and conditional probability. Given that it was a bunch of tedious algebra and Stack Overflow doesn't render math gracefully, I'm jumping straight to the result, expressed in pseudocode:
if Y[i] == 1:
generate Y[i+1] as Bernoulli(p + rho * (1 - p))
else:
generate Y[i+1] as Bernoulli(p * (1 - rho))
As a sanity check you can see that if rho = 0 it just generates Bernoulli(p)'s, regardless of the prior value. As you already noted in your question, Bernoulli RVs are binomials with n = 1.
This works for all 0 <= rho, p <= 1. For negative correlations, there are constraints on the relative magnitudes of p and rho so that the parameters of the Bernoullis are always between 0 and 1.
You can analytically check the conditional probabilities to confirm correctness. I don't use Stata, but I tested this pretty thoroughly in the JMP statistical software and it works like a charm.
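For reference, here is roughly the algebra that was skipped (my sketch of the derivation, writing q1 = P(Y[i+1]=1 | Y[i]=1) and q0 = P(Y[i+1]=1 | Y[i]=0)):

rho = Cov(Y[i], Y[i+1]) / Var(Y) = (p*q1 - p^2) / (p*(1 - p))   =>   q1 = p + rho*(1 - p)
p = P(Y[i+1]=1) = p*q1 + (1 - p)*q0                             =>   q0 = p*(1 - q1)/(1 - p) = p*(1 - rho)

which are exactly the two Bernoulli parameters used in the pseudocode above.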
IMPLEMENTATION (Python)
import random

def Bernoulli(p):
    return 1 if random.random() <= p else 0  # yields 1 w/ prob p, 0 otherwise

N = 100000
p = 0.7
rho = 0.5
last_y = Bernoulli(p)
for _ in range(N):
    if last_y == 1:
        last_y = Bernoulli(p + rho * (1 - p))
    else:
        last_y = Bernoulli(p * (1 - rho))
    print(last_y)
I ran this and redirected the results to a file, then imported the file into JMP. Analyzing it as a time series produced the following:
The sample mean was 0.69834, with a standard deviation of 0.4589785 [upper right of the figure]. The lag-1 estimates for autocorrelation and partial correlation are 0.5011 [bottom left and right, respectively]. These estimated values are all excellent matches to a Bernoulli(0.7) with rho = 0.5, as specified in the demo program.
If the goal is instead to produce (X,Y) pairs with the specified correlation, revise the loop to:
for _ in range(N):
    x = Bernoulli(p)
    if x == 1:
        y = Bernoulli(p + rho * (1 - p))
    else:
        y = Bernoulli(p * (1 - rho))
    print(x, y)