JAX

Christian Elsasser (che@physik.uzh.ch)

The content of this lecture may be distributed under CC BY-SA.

In [1]:
import jax                # core JAX functionality
import jax.numpy as jnp   # JAX's NumPy-compatible array API
import numpy as np        # regular NumPy
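
Before running the benchmarks it is useful to know which backend JAX dispatches to, since all timings below depend on it. The following cell is a small illustrative addition (not part of the original lecture); jax.devices and jax.default_backend are standard JAX calls.

In [ ]:
# List the devices JAX can see and the default backend (cpu/gpu/tpu);
# the benchmark numbers below depend strongly on this.
print(jax.devices())
print(jax.default_backend())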

We create a large NumPy array and convert it to a JAX array.

In [2]:
a = np.random.randn(10_000,10_000)
In [3]:
j = jnp.array(a)

We create a third-order polynomial function of the form $$f(x) = -4x^3+9x^2+6x-3$$ together with its (jit-)compiled version.

In [4]:
def f(x):
    return -4*x*x*x + 9*x*x + 6*x - 3

@jax.jit # <- this marks the function for JIT compilation (compiled at first call)
def f_compiled(x):
    return -4*x*x*x + 9*x*x + 6*x - 3
# we can also do f_compiled = jax.jit(f)
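
To see what the JIT compiler actually works with, we can look at the traced intermediate representation of f. This cell is an illustrative addition (not part of the original lecture); jax.make_jaxpr is standard JAX API.

In [ ]:
# Show the jaxpr (JAX's intermediate representation) obtained by tracing f;
# this is what XLA compiles when the function is jitted.
print(jax.make_jaxpr(f)(1.0))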

Let's benchmark the different cases.

In [5]:
%timeit f(a)                                         # NumPy
%timeit f(jnp.array(a)).block_until_ready()          # JAX
%timeit f_compiled(jnp.array(a)).block_until_ready() # JAX + JIT pre-compiled
%timeit jax.jit(f)(jnp.array(a)).block_until_ready() # JAX + JIT 

# Since JAX operates asynchronously, we need to ensure that the computation
# has concluded before the next loop iteration starts,
# which is why we call block_until_ready().
1.82 s ± 7.69 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
643 ms ± 24.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
209 ms ± 1.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
210 ms ± 1.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
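
The following cell illustrates why block_until_ready() matters: without it the call returns as soon as the computation has been dispatched, not when it has finished. This is a minimal sketch added for illustration and not part of the original lecture.

In [ ]:
import time

start = time.perf_counter()
result = f_compiled(j)         # returns almost immediately (asynchronous dispatch)
dispatched = time.perf_counter() - start
result.block_until_ready()     # wait until the computation has actually finished
finished = time.perf_counter() - start
print(f"dispatch only: {dispatched:.4f} s, with block_until_ready: {finished:.4f} s")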

In the JAX cases above we assumed that the jnp.array has to be created just for this one operation, which is rather conservative and unfavourable for JAX.
So let's test it with the array already created.

In [7]:
j = jnp.array(a)
In [8]:
%timeit f(j).block_until_ready()          # JAX
%timeit f_compiled(j).block_until_ready() # JAX + JIT pre-compiled
%timeit jax.jit(f)(j).block_until_ready() # JAX + JIT 
498 ms ± 5.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
67.2 ms ± 295 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
67.4 ms ± 58.3 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
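
The pre-compiled and on-the-fly jitted versions perform essentially the same here because the compiled executable is cached after the first call. Compilation happens per input shape and dtype, so a call with a new shape triggers a fresh compilation on the first call only. The following sketch illustrates this; it is an addition to the original lecture and the array small is made up for the example.

In [ ]:
import time

small = jnp.ones((100, 100))   # a shape f_compiled has not seen yet
t0 = time.perf_counter(); f_compiled(small).block_until_ready(); t1 = time.perf_counter()
f_compiled(small).block_until_ready();                           t2 = time.perf_counter()
print(f"first call with new shape (incl. compilation): {t1 - t0:.4f} s")
print(f"second call with same shape (cached):          {t2 - t1:.4f} s")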
The speed-up also shows up for single element-wise operations:

In [9]:
%timeit np.exp(j)
%timeit jnp.exp(j).block_until_ready()
153 ms ± 1.88 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
59 ms ± 232 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [10]:
%timeit (2*a+3)
%timeit (2*j+3).block_until_ready()
499 ms ± 23.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
86 ms ± 377 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Auto-differentiation

We are going to compare the auto-differentiation functionality in JAX with standard analytical and numerical derivatives (the latter using numdifftools).

As the benchmark function we again use the third-order polynomial

$$f(x) = -4x^3+9x^2+6x-3$$

with the analytical derivative

$$f'(x) = -12x^2+18x+6$$

In [11]:
import numdifftools as nd
In [12]:
# Analytical
df_a = lambda x : -12*x*x + 18*x + 6

# Analytical + JIT
df_a_comp = jax.jit(df_a)

# Numerical
df_n = nd.Derivative(f)

# Auto-differentiation
df_j = jax.grad(f)

# Auto-differentiation + JIT
df_j_comp = jax.jit(df_j)
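
Before benchmarking we can sanity-check that the auto-differentiated gradient agrees with the analytical derivative at a single point, e.g. $f'(1) = -12 + 18 + 6 = 12$. This cell is an illustrative addition, not part of the original lecture.

In [ ]:
# jax.grad expects a scalar (floating-point) input; both values should be 12.0
print(df_j(1.0))   # auto-differentiation
print(df_a(1.0))   # analytical derivative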
In [13]:
b = np.random.randn(1_000,1_000)
In [14]:
%timeit df_a(b) # Analytical derivative
%timeit df_a_comp(b) # Analytical derivative compiled

%timeit df_n(b) # Numerical derivative

# We need to vectorize the auto-differentiation functions since we apply them to an array
%timeit jnp.vectorize(df_j)(jnp.array(b)).block_until_ready() # Auto-diff
%timeit jnp.vectorize(df_j_comp)(jnp.array(b)).block_until_ready() # Auto-diff compiled
5.31 ms ± 57.9 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.4 ms ± 5.53 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
4.92 s ± 31.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
29.3 ms ± 5.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
5.22 ms ± 36.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
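
As an alternative to jnp.vectorize, jax.vmap can map the scalar gradient over the array axes and be combined with jit. The following cell is a sketch of that variant, added for illustration and not part of the original lecture.

In [ ]:
# vmap the scalar gradient over both axes of the 2-D array, then jit the result
df_vmap = jax.jit(jax.vmap(jax.vmap(jax.grad(f))))
df_vmap(jnp.array(b)).block_until_ready()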