#!/usr/bin/env python

#######################################################################
# This script compares the speed of the computation of a polynomial
# for different libraries: numpy, numexpr and numba.
#
# Author: Francesc Alted, Roman Gredig
# Date: 2020-06-11
#######################################################################

import argparse
import math
from numba import double, jit
from time import time
import numpy as np
import numexpr as ne

# Problem size and input data shared by every benchmark below.
N = 10_000_000           # number of points to evaluate
x = np.linspace(-1, 1, N)  # vector x in range [-1, 1]

# The different expressions supported.  Each string is handed verbatim to
# numexpr; entries 0 and 1 are also passed to eval() for the numpy run, so
# they must stay valid Python expressions in terms of the name `x`.
expr = [
    ".25*x**3 + .75*x**2 - 1.5*x - 2",  # 0) the polynomial to compute
    "((.25*x + .75)*x - 1.5)*x - 2",    # 1) the same polynomial in Horner form
    "x",                                # 2) the identity function
    "sin(x)**2 + cos(x)**2",            # 3) a transcendental function
    ]

# Command line: --expression-index picks one entry of `expr` (default 0).
# `choices` restricts the value to the valid indices of the list.
parser = argparse.ArgumentParser()
parser.add_argument('--expression-index', type=int, default=0, choices=range(len(expr)),
                    help='select the expression to compute')
args = parser.parse_args()

# Index into `expr`; read as a global by poly() and by the benchmark code.
to_compute = args.expression_index


# A function that is going to be accelerated by numba
def poly(x):
    """Evaluate the selected benchmark expression one element at a time.

    The branch taken is chosen by the module-level ``to_compute`` index and
    must compute the same values as ``expr[to_compute]``, because the numpy
    and numexpr runs evaluate those strings directly and all results are
    meant to be comparable.

    The explicit index loop is deliberate: this is the function that gets
    compiled by numba and that is also timed as plain Python.

    x: indexable sequence of floats holding at least ``N`` elements.

    Returns a new float64 array of length ``N``.
    """
    y = np.empty(N, dtype=np.float64)
    for i in range(N):
        if to_compute == 0:
            # .25*x**3 + .75*x**2 - 1.5*x - 2  (note the minus on the linear term)
            y[i] = 0.25*x[i]**3 + 0.75*x[i]**2 - 1.5*x[i] - 2
        elif to_compute == 1:
            # Horner form of the same polynomial
            y[i] = ((0.25*x[i] + 0.75)*x[i] - 1.5)*x[i] - 2
        elif to_compute == 2:
            y[i] = x[i]
        elif to_compute == 3:
            y[i] = math.sin(x[i])**2 + math.cos(x[i])**2
    return y


# ---------------------------------------------------------------------
# Benchmark driver: evaluate the selected expression with numpy, numexpr,
# numba and plain Python, timing each one, then print relative speedups.
# ---------------------------------------------------------------------
print("Using expression", expr[to_compute], "with", N, "points")
print()
print("*** Running numpy!")
start = time()
if "sin" in expr[to_compute]:
    # The expression string uses bare sin/cos, which are not defined as
    # Python names here, so spell it out with the numpy functions.
    y = np.sin(x)**2 + np.cos(x)**2
elif "x" == expr[to_compute]:
    # Trick to force a copy with NumPy
    y = x.copy()
else:
    # eval() is acceptable here: the string is one of the hard-coded
    # entries of `expr`, selected by an index that argparse validated.
    y = eval(expr[to_compute])
tnumpy = time() - start
print("Result from numpy is", y, "in", round(tnumpy, 3), "sec")

print()
print("*** Running numexpr!")
# Configure the thread count before starting the clock so that only the
# evaluation itself is measured.
ne.set_num_threads(1)   # change the number of threads if you want
start = time()
y = ne.evaluate(expr[to_compute], optimization='aggressive')
tnumexpr = time() - start
print("Result from numexpr is", y, "in", round(tnumexpr, 3), "sec")

print()
print("*** Running numba!")
# Passing an explicit signature makes numba compile right away, so the
# compilation cost can be timed separately from the actual run.
start = time()
cpoly = jit(double[:](double[:]))(poly)
tcompile = time() - start
print("Compilation time for numba:", round(tcompile, 3))

start = time()
# Keep the result: the line below must report what numba computed, not
# the array left over from the numexpr run.
y = cpoly(x)
tnumba = time() - start
print("Result from numba is", y, "in", round(tnumba, 3), "sec")

print()
print("*** Running poly with native python! This might take a while.")
start = time()
# `poly` is still the uncompiled function; capture its result for printing.
y = poly(x)
tpython = time() - start
print("Result from python is", y, "in", round(tpython, 3), "sec")


print()
print("*** Speedup summary:")
print("numexpr vs numpy speedup is", (tnumpy / tnumexpr))
# numba is charged for compilation plus execution in both comparisons.
print("numba vs numpy speedup is", (tnumpy / (tcompile + tnumba)))
print("numba vs python speedup is", (tpython / (tcompile + tnumba)))
print()
