#!/usr/bin/env python

from time import perf_counter
from time import time

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageOps


# In general dithering works with color as well; for simplicity we look only at grayscale images.
# Therefore convert the image to 8-bit grayscale (values 0 to 255). The grayscale image is also
# inverted so that dark pixels get high values: the 'binary' colormap used for plotting draws
# high values as black, so the plots still look like the original picture.
# The context manager closes the source file once the converted copy has been made.
with Image.open('monet.jpg') as im:
    grayscale = ImageOps.invert(im.convert('L'))
image_data = np.array(grayscale, dtype=float)

# in case the image is too big you might want to look only at a certain window
# image_data = image_data[760:1020, 600:1000]

# you can also load just a gradient image to understand more easily what is going on
# image_data = np.tile(np.linspace(0, 255, 256), (16,1))

#  Ordered Dither https://en.wikipedia.org/wiki/Ordered_dithering
# 4x4 index matrix: every value 0..15 appears exactly once. Lower indices get lower
# thresholds, so they decide which pixels of a 4x4 cell switch on first as intensity rises.
index_matrix = np.array([[5, 9, 6, 10], [13, 1, 14, 2], [7, 11, 4, 8], [15, 3, 12, 0]])
# Map index k to the threshold (k + 1) / (16 + 1) * 255. This yields 16 thresholds spread
# evenly between 15 and 240; 0 and 255 themselves are never used as thresholds.
threshold_matrix = (index_matrix+1)/(index_matrix.size+1)*255
# side length of the (square) index matrix, used to wrap pixel coordinates
index_matrix_length = index_matrix.shape[0]


def ordered_dither(original_image, threshold=None):
    """Binarize a 2-D image with ordered dithering, using explicit Python loops.

    This is the straightforward reference implementation. Every pixel is
    compared with the threshold-matrix entry at the same position modulo the
    matrix size, i.e. the matrix is conceptually repeated over the whole image.

    Args:
        original_image: 2-D array (or nested sequence) of pixel intensities.
        threshold: optional 2-D threshold matrix; it does not need to be
            square. Defaults to the module-level ``threshold_matrix``.

    Returns:
        A ``uint8`` array with the same shape as ``original_image``. It holds 1
        where the pixel is strictly greater than its threshold and 0 elsewhere.
    """
    if threshold is None:
        threshold = threshold_matrix
    threshold = np.asarray(threshold)
    original_image = np.asarray(original_image)
    rows, cols = original_image.shape
    # Wrap each axis by its own length so non-square matrices tile correctly.
    t_rows, t_cols = threshold.shape
    new_image = np.zeros(original_image.shape, dtype=np.uint8)
    for row in range(rows):
        for col in range(cols):
            if original_image[row, col] > threshold[row % t_rows, col % t_cols]:
                new_image[row, col] = 1
    return new_image


def optimized_ordered_dither(original_image, threshold=None):
    """Binarize a 2-D image with ordered dithering using vectorized NumPy operations.

    Produces the same result as ``ordered_dither``, but replaces the
    per-pixel Python loops with a single array comparison.

    Args:
        original_image: 2-D array (or nested sequence) of pixel intensities.
        threshold: optional 2-D threshold matrix; it does not need to be
            square. Defaults to the module-level ``threshold_matrix``.

    Returns:
        A ``uint8`` array with the same shape as ``original_image``. It holds 1
        where the pixel is strictly greater than its threshold and 0 elsewhere.
    """
    if threshold is None:
        threshold = threshold_matrix
    threshold = np.asarray(threshold)
    original_image = np.asarray(original_image)
    rows, cols = original_image.shape
    t_rows, t_cols = threshold.shape
    # Instead of cycling through the threshold matrix pixel by pixel, build a
    # full-size threshold map in one step. Row r / column c of the map comes
    # from row r % t_rows / column c % t_cols of the matrix. This is exactly
    # what tiling the matrix and cropping it to the image size would give,
    # but without allocating the surplus tiles.
    row_idx = np.arange(rows) % t_rows
    col_idx = np.arange(cols) % t_cols
    extended_threshold_matrix = threshold[np.ix_(row_idx, col_idx)]
    # The two for loops of ordered_dither become one vectorized comparison.
    return (original_image > extended_threshold_matrix).astype(np.uint8)


# NOTE: the 'binary' colormap draws vmin as white and vmax as black.
plt.figure('Original Grayscale')
plt.imshow(image_data, vmin=0, vmax=0xff, cmap='binary')


# the slow dithering algorithm
print('dithering with original algorithm ...')
# perf_counter is the clock intended for measuring short durations. It is
# monotonic and high-resolution, whereas time() is a wall clock that can be
# adjusted and has coarse resolution on some platforms.
start_time = perf_counter()
ordered_dither_image = ordered_dither(image_data)
stop_time = perf_counter()
original_timedelta = stop_time - start_time
print(f'ordered dither took {round(original_timedelta, 3)} sec')
plt.figure('ordered Dithered image')
plt.imshow(ordered_dither_image, vmin=0, vmax=1, cmap='binary')


print('dithering with your optimized algorithm ...')
start_time = perf_counter()
optimized_ordered_dither_image = optimized_ordered_dither(image_data)
stop_time = perf_counter()
optimized_timedelta = stop_time - start_time
print(f'optimized_ordered took {round(optimized_timedelta, 3)} sec')

# Both implementations are expected to produce exactly the same binary image.
if not np.array_equal(ordered_dither_image, optimized_ordered_dither_image):
    print('\n\n\n')
    print('! WARNING ! Something went wrong. Your image is different. We will plot it anyways\n\n\n')

plt.figure('your optimized ordered Dithered image')
plt.imshow(optimized_ordered_dither_image, vmin=0, vmax=1, cmap='binary')

# Guard the division: a very fast run can measure as 0 seconds. That would raise
# ZeroDivisionError and abort the script before the plots are shown.
if optimized_timedelta > 0:
    print(f'speedup factor is {original_timedelta/optimized_timedelta:.2f}')
else:
    print('speedup factor could not be measured (optimized run was below the timer resolution)')

plt.show()
