Christian Elsasser (che@physik.uzh.ch)
The content of this lecture may be distributed under CC BY-SA.
import jax               # core JAX functionality
import jax.numpy as jnp  # JAX's NumPy-compatible array API
import numpy as np       # regular NumPy
We create a large NumPy array and a JAX copy of it
a = np.random.randn(10_000, 10_000)
j = jnp.array(a)
We create a third-order polynomial function and its (jit-)compiled version
def f(x):
    return -4*x*x*x + 9*x*x + 6*x - 3

@jax.jit  # <- this compiles the function
def f_compiled(x):
    return -4*x*x*x + 9*x*x + 6*x - 3

# we can also do f_c = jax.jit(f)
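As a side note (a minimal sketch, not part of the lecture's benchmark): `jax.jit` does not execute the Python body on every call. It traces the function once per input shape and dtype, compiles the traced computation with XLA, and caches the result. A `print` inside the function therefore fires only during tracing:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f_traced(x):
    print("tracing...")  # executed only while JAX traces the function
    return -4*x*x*x + 9*x*x + 6*x - 3

x = jnp.arange(4.0)
f_traced(x)  # triggers tracing + compilation, prints "tracing..."
f_traced(x)  # same shape/dtype: the cached program is reused, no print
```

This also explains why the first call to a jitted function is slow: it pays the one-off compilation cost that later calls amortise.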
Let's benchmark the different cases
%timeit f(a) # NumPy
%timeit f(jnp.array(a)).block_until_ready() # JAX
%timeit f_compiled(jnp.array(a)).block_until_ready() # JAX + JIT pre-compiled
%timeit jax.jit(f)(jnp.array(a)).block_until_ready() # JAX + JIT
# Since JAX dispatches operations asynchronously, we need to ensure that the
# computation has concluded before the next loop iteration starts.
1.8 s ± 14.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
874 ms ± 5.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
287 ms ± 8.35 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
281 ms ± 4.54 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
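The asynchronous dispatch can be made visible directly (a sketch; the absolute timings depend on your machine and backend): the JAX call returns almost immediately with a "future-like" array, while `block_until_ready()` waits for the actual computation to finish.

```python
import time
import numpy as np
import jax.numpy as jnp

j = jnp.array(np.random.randn(2_000, 2_000))

t0 = time.perf_counter()
r = jnp.dot(j, j)             # returns almost immediately (asynchronous dispatch)
t_dispatch = time.perf_counter() - t0

r.block_until_ready()         # blocks until the computation has actually finished
t_total = time.perf_counter() - t0

print(f"dispatch: {t_dispatch:.4f} s, total: {t_total:.4f} s")
```

Without `block_until_ready()`, `%timeit` would mostly measure how fast JAX can enqueue work, not how fast it computes.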
We assumed for the JAX cases that the jnp.array has to be created for this operation alone,
which is rather conservative and unfavourable to JAX.
So let's test it with the array already created
j = jnp.array(a)
%timeit f(j).block_until_ready() # JAX
%timeit f_compiled(j).block_until_ready() # JAX + JIT pre-compiled
%timeit jax.jit(f)(j).block_until_ready() # JAX + JIT
665 ms ± 2.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
66.3 ms ± 347 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
66.6 ms ± 250 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Simple element-wise operations show a similar picture
%timeit np.exp(a)                       # NumPy
%timeit jnp.exp(j).block_until_ready()  # JAX
153 ms ± 2.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
59.3 ms ± 184 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit (2*a+3)                       # NumPy
%timeit (2*j+3).block_until_ready()   # JAX
465 ms ± 2.79 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
127 ms ± 606 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
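Part of this gap comes from memory traffic: NumPy materialises the intermediate array `2*a` before adding 3, whereas jit-compiling the expression lets XLA fuse both element-wise operations into a single pass over the data (a sketch, not from the lecture):

```python
import jax
import jax.numpy as jnp

@jax.jit
def affine(x):
    # XLA can fuse the multiply and the add into one kernel,
    # so no intermediate array for 2*x is written back to memory
    return 2*x + 3

y = affine(jnp.ones((1_000, 1_000))).block_until_ready()
```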
We now compare different ways of computing the derivative of f
import numdifftools as nd
# Analytical
df_a = lambda x : -12*x*x + 18*x + 6
df_a_comp = jax.jit(df_a)
# Numerical
df_n = nd.Derivative(f)
# Auto-differentiation
df_j = jax.grad(f)
df_j_comp = jax.jit(df_j)
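Note that `jax.grad` differentiates functions with a scalar output, which is why the benchmarks below wrap `df_j` in `jnp.vectorize` before applying it to an array. On a single value it agrees with the analytical derivative (a quick sanity check, not part of the lecture's benchmark):

```python
import jax

def f(x):
    return -4*x*x*x + 9*x*x + 6*x - 3

df = jax.grad(f)                      # gradient of a scalar -> scalar function
df_a = lambda x: -12*x*x + 18*x + 6   # analytical derivative

x = 2.0
print(df(x))    # -6.0
print(df_a(x))  # -6.0
```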
b = np.random.randn(1_000, 100)
%timeit df_a(b) # Analytical derivative
%timeit df_a_comp(b).block_until_ready() # Analytical derivative compiled
%timeit df_n(b) # Numerical derivative
%timeit jnp.vectorize(df_j)(jnp.array(b)).block_until_ready() # Auto-diff
%timeit jnp.vectorize(df_j_comp)(jnp.array(b)).block_until_ready() # Auto-diff compiled
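As an alternative to `jnp.vectorize` (my addition, not benchmarked in the lecture), `jax.vmap` maps the scalar gradient over the array axes and composes naturally with `jax.jit`:

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    return -4*x*x*x + 9*x*x + 6*x - 3

# vmap over both axes turns the scalar gradient into an element-wise one
df_vmapped = jax.jit(jax.vmap(jax.vmap(jax.grad(f))))

b = jnp.array(np.random.randn(1_000, 100))
g = df_vmapped(b).block_until_ready()

# agrees with the analytical derivative -12x^2 + 18x + 6
assert jnp.allclose(g, -12*b*b + 18*b + 6, atol=1e-3)
```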