#!/usr/bin/env python

#######################################################################
# This script compares the speed of the computation of a polynomial
# using multiple processes (numpy) or threads (numexpr).
#
# Author: Francesc Alted, Roman Gredig
# Date: 2020-06-11
#######################################################################

import argparse
from time import time
from multiprocessing import Pool
import numpy as np
import numexpr as ne

# Number of sample points to evaluate
N = 10_000_000
# Evenly spaced sample vector covering the closed range [-1, 1]
x = np.linspace(-1.0, 1.0, num=N)

# Expressions that can be selected by index
expr = [
    # 0) the polynomial to compute
    ".25*x**3 + .75*x**2 - 1.5*x - 2",
    # 1) a computer-friendly polynomial
    "((.25*x + .75)*x - 1.5)*x - 2",
    # 2) the identity function
    "x",
    # 3) a transcendental function
    "sin(x)**2 + cos(x)**2",
]


def compute(x, nt, expression_index, library):
    """Evaluate the selected expression over `x` with the chosen backend.

    `expression_index` selects an entry of the module-level `expr` list.
    When `library` is "numpy" the work is handed to `compute_parallel`
    with `nt` workers; any other value goes through numexpr after setting
    its thread count to `nt`.  The evaluated array is returned.
    """
    selected = expr[expression_index]
    if library != "numpy":
        ne.set_num_threads(nt)
        # No variables are passed explicitly, so numexpr has to find `x` on
        # its own; the parameter must therefore keep the name `x`.
        return ne.evaluate(selected)
    return compute_parallel(selected, x, nt)


def compute_block(expr_, xp, nt, i):
    """Evaluate `expr_` with numpy on the `i`-th of `nt` chunks of `xp`.

    The chunk boundaries are derived from the length of `xp` itself, so
    the function works for an input of any length.

    Parameters:
        expr_: expression string written in terms of `x`; bare `sin` and
            `cos` calls are resolved to their numpy counterparts.
        xp: the full input array.
        nt: total number of chunks the array is split into.
        i: zero-based index of the chunk to evaluate.

    Returns:
        A tuple `(y, nt, i)` where `y` is the evaluated chunk; `nt` and `i`
        are echoed back so the receiver can place `y` correctly.
    """
    n = len(xp)
    x = xp[i*n//nt:(i+1)*n//nt]
    if expr_ == "x":
        # Trick to force a copy with numpy
        return x.copy(), nt, i
    # Trick to allow numpy to evaluate numexpr-style sin()/cos() calls:
    # expose them by name instead of rewriting the expression text.
    namespace = {"np": np, "sin": np.sin, "cos": np.cos, "x": x}
    # NOTE: eval() is only acceptable here because the expressions come from
    # the fixed list in this script; never feed it untrusted input.
    y = eval(expr_, namespace)
    return y, nt, i


# Output buffer of N float64 values; `cb` writes each computed chunk into it.
result = np.empty(N, dtype=np.float64)


def cb(r):
    """Copy one computed chunk into its slot of the module-level `result`.

    `r` is a `(y, nt, i)` tuple as returned by `compute_block`: the chunk
    values, the total number of chunks and this chunk's index.
    """
    chunk, n_chunks, index = r
    # Same boundaries that were used to cut the input into chunks
    lo = index * N // n_chunks
    hi = (index + 1) * N // n_chunks
    result[lo:hi] = chunk


# Parallel computation for numpy via multiprocessing
def compute_parallel(expr_, x, nt):
    """Evaluate `expr_` over `x` using `nt` worker processes.

    One `compute_block` task is submitted per chunk; each finished chunk is
    stored into the module-level `result` array by the `cb` callback.

    Parameters:
        expr_: expression string to evaluate.
        x: the full input array, passed to every task.
        nt: number of worker processes, which is also the number of chunks.

    Returns:
        The module-level `result` array (the same object on every call).

    Raises:
        The first exception raised by a worker task, if any task failed.
    """
    print("Computing with", nt, "threads in parallel:", expr_)
    failures = []  # exceptions reported by failed worker tasks
    # The context manager makes sure the pool is torn down even if
    # submitting or waiting is interrupted.
    with Pool(processes=nt) as po:
        for i in range(nt):
            po.apply_async(compute_block, (expr_, x, nt, i),
                           callback=cb, error_callback=failures.append)
        po.close()
        po.join()
    if failures:
        # Without this a failed chunk would silently leave stale data
        # in `result`.
        raise failures[0]
    return result


if __name__ == '__main__':
    # Command-line options: which expression, which library, and up to how
    # many workers/threads to benchmark.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--expression-index', type=int, default=0, choices=range(len(expr)),
        help='select the expression to compute',
    )
    cli.add_argument(
        '--library', type=str, default='numpy', choices=['numpy', 'numexpr'],
        help='select library to use',
    )
    cli.add_argument(
        '--threads', type=int, default=1,
        help='maximum number of threads be used in parallel',
    )
    opts = cli.parse_args()

    banner = (
        f"Computing: {expr[opts.expression_index]}, "
        f"using {opts.library} with maximum {opts.threads} threads "
        f"and {int(N / 1e6)} million points"
    )
    print(banner)

    # Time one run for every thread count from 1 up to the requested maximum
    for n_threads in range(1, opts.threads + 1):
        t_begin = time()
        y = compute(x, n_threads, opts.expression_index, opts.library)
        t_end = time()
        elapsed = round(t_end - t_begin, 3)
        print(y)
        print(f"*** Time elapsed for {n_threads} threads: {elapsed} sec")
