#!/usr/bin/env python
#######################################################################
# This script compares the speed of the computation of a polynomial
# for different in-memory libraries: numpy and numexpr.
#
# Author: Francesc Alted, Roman Gredig
# Date: 2020-06-11-04
#######################################################################

import argparse
from time import time
import numpy as np
import numexpr as ne

# Number of sample points at which each expression is evaluated.
N = 10_000_000
# Input vector: N evenly spaced values spanning the interval [-1, 1].
x = np.linspace(-1.0, 1.0, num=N)

# Catalogue of the expressions this script can evaluate. Choose one on the
# command line with --expression-index; the number in each comment below is
# the index to pass.
expressions = [
    # 0) the polynomial to compute
    ".25*x**3 + .75*x**2 - 1.5*x - 2",
    # 1) the same polynomial in nested (Horner) form, which needs fewer
    #    multiplications
    "((.25*x + .75)*x - 1.5)*x - 2",
    # 2) the identity function
    "x",
    # 3) a transcendental function
    "sin(x)**2 + cos(x)**2",
]


def compute(expression_index, library):
    """Evaluate one of the benchmark expressions over the global vector ``x``.

    Parameters
    ----------
    expression_index : int
        Index into the module-level ``expressions`` list selecting the
        expression to evaluate.
    library : str
        ``"numpy"`` evaluates the expression with NumPy (through ``eval``);
        any other value evaluates it with numexpr restricted to one thread.

    Returns
    -------
    The result of evaluating the selected expression on ``x``.

    Raises
    ------
    IndexError
        If ``expression_index`` is outside the bounds of ``expressions``.
    """
    expr_ = expressions[expression_index]

    if library != "numpy":
        ne.set_num_threads(1)  # force numexpr to use only 1 thread
        return ne.evaluate(expr_)

    if expr_ == "x":
        # Evaluating the bare name would only hand back a reference to the
        # global array, so no work would be measured.  Force a copy and
        # return that copy (previously the copy was made and then thrown
        # away, and the global ``x`` itself was returned).
        return x.copy()

    if "sin" in expr_:
        # The stored expression uses unqualified sin/cos, which only numexpr
        # understands; substitute the NumPy-qualified equivalent.
        expr_ = "np.sin(x)**2+np.cos(x)**2"

    # NOTE: eval() is only ever given strings from the hard-coded
    # ``expressions`` list (or the literal above), never external input.
    return eval(expr_)


if __name__ == '__main__':
    # Script entry point: parse the command-line options, evaluate the chosen
    # expression once with the chosen library, and report the result together
    # with the elapsed time.

    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # short durations; time() is a wall clock that can jump if the system
    # clock is adjusted.  Imported locally so this block is self-contained.
    from time import perf_counter

    parser = argparse.ArgumentParser()
    parser.add_argument('--expression-index', type=int, default=0, choices=range(len(expressions)),
                        help='select the expression to compute')
    parser.add_argument('--library', type=str, default='numpy', choices=['numpy', 'numexpr'],
                        help='select library to use')
    args = parser.parse_args()
    print(f"Computing: {expressions[args.expression_index]}, using {args.library} with {N} points")

    # Only the call to compute() is timed; argument parsing and printing are
    # deliberately kept outside the measured interval.
    t0 = perf_counter()
    result = compute(args.expression_index, args.library)
    ts = round(perf_counter() - t0, 3)

    print(f"result: {result}")
    print(f"*** Time elapsed: {ts} sec")
