Computational complexity of higher-order derivatives with AD in JAX / TensorFlow

Let f: R -> R be an infinitely differentiable function. What is the computational complexity of calculating the first n derivatives of f in JAX? A naive application of the chain rule suggests that each multiplication doubles the work, so the nth derivative would require at least 2^n times more operations. I imagine, though, that clever manipulation of formal power series would reduce the number of required calculations and eliminate duplication, especially if the derivatives are JAX-jitted. Is there a difference between the JAX, TensorFlow and Torch implementations?
https://openreview.net/forum?id=SkxEF3FNPH discusses this topic, but doesn't provide a computational complexity.

What is the computational complexity of calculating the first n derivatives of f in Jax?
There's not much you can say in general about the computational complexity of Nth derivatives. For example, with a function like jnp.sin, the Nth derivative is O[1], cycling through sin and cos calls with alternating signs as N grows. For an order-k polynomial, the Nth derivative is identically zero for N > k, so it too costs O[1]. Other functions may have complexity that is linear, polynomial, or even exponential in N, depending on the operations they contain.
I imagine though that clever manipulation of formal series would reduce the number of required calculations and eliminate duplication, especially if the derivatives are JAX-jitted
You imagine correctly! One implementation of this idea is the jax.experimental.jet module, an experimental transform designed for computing higher-order derivatives efficiently and accurately via truncated Taylor series (Taylor-mode AD). It doesn't cover all JAX functions, but it may be complete enough to do what you have in mind.
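To illustrate (my own example, not from the answer): jet pushes a truncated Taylor series through the function in a single pass, rather than nesting jax.grad n times. With input coefficients (1, 0, 0), the output coefficients are exactly the first three derivatives at the expansion point:

```python
import math
import jax.numpy as jnp
from jax.experimental.jet import jet

# Push the Taylor series of x(t) = x0 + t through sin in one pass.
# With input coefficients (1., 0., 0.), the output coefficients are the
# first three derivatives: cos(x0), -sin(x0), -cos(x0).
x0 = 0.5
primal_out, coeffs = jet(jnp.sin, (x0,), ((1.0, 0.0, 0.0),))
```

Here `primal_out` is sin(x0) and `coeffs` holds the derivative values, all computed together instead of by repeated differentiation of the computation graph.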

Related

kNN-DTW time complexity

I found from various online sources that the time complexity of DTW is quadratic. On the other hand, I also found that standard kNN has linear time complexity. However, when pairing them together, does kNN-DTW have quadratic or cubic time complexity?
In essence, does the time complexity of kNN depend solely on the metric used? I have not found a clear answer to this.
You need to be careful here. Say you have n time series of length l in your 'training' set (let's call it that, even though kNN does no real training). Computing the DTW distance between a pair of time series has an asymptotic complexity of O(l * m), where m is your maximum warping window; since m <= l, O(l^2) also holds. (There are asymptotically faster formulations, but I don't think they are actually faster in practice in most cases, see here.) Classifying a time series with kNN requires computing the distance between that series and every series in the training set, i.e. n comparisons, which is linear in n.
So your final complexity is O(l * m * n), or O(l^2 * n) in the worst case. In words: the complexity is quadratic in the time series length and linear in the number of training examples.
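A minimal sketch of the two nested costs (my own illustration, assuming equal-length series and a Sakoe-Chiba band): the DTW table fills O(l * m) cells, and 1-NN classification calls it n times.

```python
import numpy as np

def dtw(a, b, window):
    """Windowed DTW distance between equal-length series: O(l * window) cells."""
    l = len(a)
    D = np.full((l + 1, l + 1), np.inf)
    D[0, 0] = 0.0
    for i in range(1, l + 1):
        lo, hi = max(1, i - window), min(l, i + window)
        for j in range(lo, hi + 1):
            cost = (a[i - 1] - b[j - 1]) ** 2
            D[i, j] = cost + min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1])
    return D[l, l]

def one_nn_dtw(query, train_X, train_y, window):
    """1-NN classification: n DTW calls, O(n * l * window) total."""
    dists = [dtw(query, x, window) for x in train_X]
    return train_y[int(np.argmin(dists))]
```

Cells outside the band stay at infinity, so the recursion automatically keeps the warping path inside the window.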

Asynchrony loss function over an array of 1D signals

So I have an array of N 1D signals (e.g. time series) with the same number of samples per signal (all in equal resolution), and I want to define a differentiable loss function that penalizes asynchrony among them, i.e. is zero when all N signals are equal to each other. I've been searching the literature for something suitable but haven't had luck yet.
A few remarks:
1 - Since N (the number of signals) can be quite large, I cannot afford to compute a mean-squared loss between every single pair, which grows combinatorially. I'm also not sure that would be optimal in any mathematical sense for the goal I want to achieve.
There are two naive loss functions I could think of:
a) A total-variation loss for each time sample across all signals (to force ideally zero variation). The problem is that the weight needs to be very large to yield zero variation, masking any other loss term that is added; also, there is no inherent order among the N signals, which makes TV loss unsuitable to begin with.
b) Minimizing the sum of the variance at each time point across all signals. However, the choice of reference for the variance (i.e. the mean) could be crucial, I believe, as just using the sample mean might not really yield the desired result; I'm not quite sure.
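For what it's worth, option (b) with the sample mean is not a compromise: for each time sample, sum_{i<j} (x_i - x_j)^2 = N^2 * Var(x), so the per-time-step variance loss equals the all-pairs squared loss up to a constant factor while costing O(N) instead of O(N^2). A minimal differentiable sketch (names are mine, not from the question):

```python
import jax
import jax.numpy as jnp

def asynchrony_loss(signals):
    """signals: (N, T) array of N equally sampled 1D signals.

    Mean over time of the across-signal variance. Zero iff all signals
    are identical; equal, up to a constant factor, to the all-pairs
    mean-squared loss, but computed in O(N * T).
    """
    return jnp.mean(jnp.var(signals, axis=0))

grad_fn = jax.grad(asynchrony_loss)  # differentiable, usable as a penalty term
```

Since the loss is a smooth function of the signals, it can be weighted and added to any other differentiable objective.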

Time complexity (Big-O notation) of Posterior Probability Calculation

I have a basic understanding of Big-O notation from its definition.
In my problem, a 2-D surface is divided into M uniform grid cells. Each cell m is assigned a posterior probability based on A features.
The posterior probability of grid cell m is calculated as follows:
and the marginal likelihood is given as:
Here the A features are independent of each other, and the sigma and mean symbols denote the standard deviation and mean of each feature a at each grid cell. I need to calculate the posterior probability of all M grid cells.
What will be the time complexity of the above operation in terms of Big-O notation?
My guess is O(M) or O(M+A). Am I correct? I'm expecting an authoritative answer to present at a formal forum.
Also, what will be the time complexity if the M grid cells are divided into T clusters, where every cluster has Q grids (Q << M), and the posterior probability is calculated only on the Q grids of a cluster rather than on all M grids?
Thank you very much.
The discrete sum and product can be understood as loops. If you are happy with floating-point approximation, most other operators are typically O(1); a conditional probability looks like a function call. Just inject constants and variables into your equation and you'll get the expected Big-O; the details of the formula are irrelevant. Also be aware that these "loops" can often be simplified using mathematical properties.
If the result is not obvious, convert your mathematical formula into actual code in a programming language. In computer science, Big-O is never about a formula as such, but about a particular translation of it into programming steps; depending on the implementation, the same formula can lead to very different complexities, as different as summing the integers 1..n with a loop, O(n), versus applying Gauss's formula, O(1).
By the way, why are you doing a discrete sum over a domain N? Shouldn't it be M?
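As an illustration only, since the exact formula isn't shown above, here is one plausible reading, independent Gaussian likelihoods per feature (an assumption of mine), translated into loops so the O(M * A) cost of the likelihoods and the O(M) cost of the marginal become explicit:

```python
import math

def posteriors(obs, means, stds, priors):
    """obs: A observed feature values; means/stds: M x A per-grid Gaussian
    parameters; priors: M prior probabilities. Returns M posteriors.
    The likelihood double loop is O(M * A); normalization is O(M)."""
    M, A = len(means), len(obs)
    joint = []
    for m in range(M):                      # loop over grid cells: M iterations
        lik = 1.0
        for a in range(A):                  # loop over features: A iterations
            z = (obs[a] - means[m][a]) / stds[m][a]
            lik *= math.exp(-0.5 * z * z) / (stds[m][a] * math.sqrt(2 * math.pi))
        joint.append(priors[m] * lik)
    marginal = sum(joint)                   # marginal likelihood: O(M)
    return [j / marginal for j in joint]
```

Under this reading the total is O(M * A), and restricting the computation to the Q grids of one cluster makes it O(Q * A).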

How Fast is Convolution Using FFT

I read that in order to compute the convolution of two 1D signals x, y of lengths N and M, the naïve method takes O(NM).
However, the FFT can be used to compute FFT^-1(FFT(x)·FFT(y)), which takes O(N log N) in the case where N > M.
I wonder why this complexity is considered better than the former, as M isn't necessarily bigger than log(N). Moreover, M is very often the length of a filter, which doesn't scale with the signal being filtered, so direct convolution actually gives us a complexity closer to O(N) than to O(N^2).
Fast convolution in the frequency domain is typically more efficient than direct convolution only once the size of the filter exceeds a particular threshold. So for relatively small filters direct convolution is more efficient, whereas for longer filters there comes a point at which FFT-based convolution wins. The actual value of M at this "tipping point" depends on many factors, but it's typically somewhere between 10 and 100.
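To make the comparison concrete, here is a small sketch (mine, not from the answer) of FFT-based linear convolution; zero-padding to length N + M - 1 makes the circular convolution computed by the FFT agree with direct linear convolution:

```python
import numpy as np

def fft_convolve(x, y):
    """Linear convolution via FFT: O((N + M) log(N + M)),
    versus O(N * M) for the direct method."""
    n = len(x) + len(y) - 1          # pad length avoids circular wrap-around
    X = np.fft.rfft(x, n)
    Y = np.fft.rfft(y, n)
    return np.fft.irfft(X * Y, n)    # inverse transform of the product
```

For a short fixed-length filter the three transforms dominate the cost, which is exactly why direct convolution stays competitive below the tipping point.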

Check how fast numpy.linalg.lstsq is finding convergence

I have a question concerning the NumPy function linalg.lstsq(a, b). Is there any way to check how fast this method converges, i.e. some characteristic that indicates the rate of convergence of the computation?
Thanks in advance for any ideas.
The NumPy function linalg.lstsq uses the singular value decomposition (SVD) to solve the least-squares problem, so it is a direct method rather than an iterative one: there is no convergence to monitor. If your matrix A is n by n, it requires O(n^3) flops.
More precisely, I believe the function uses Householder bidiagonalization to compute the SVD, so if your matrix is m by n, the complexity is O(max(m, n) * min(m, n)^2).
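Although there is no iteration count to inspect, the call does return diagnostics (residuals, rank, singular values) that tell you how well-conditioned the problem was. A small sketch with random data of my own:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((100, 5))
b = rng.standard_normal(100)

# One-shot direct solve: no iterations, hence no convergence rate to report.
x, residuals, rank, svals = np.linalg.lstsq(A, b, rcond=None)

# The singular values (descending) give the condition number of A, which is
# the closest analogue of a "difficulty" indicator for this solver.
cond = svals[0] / svals[-1]
```

A large condition number signals a numerically hard problem, the role a convergence rate would play for an iterative solver.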