# Pattern3.py

from gpanel import *
from operator import itemgetter
import time
import random

def onMousePressed(x, y):
    """Mouse-press callback.

    Left button inside the drawing field (-100..100 on both axes): if a new
    drawing is due (startDrawing), reset the panel with initGraphics() first,
    then move the pen to (x, y) so the following drags draw from there.

    Right button: shrink the drawn field to a binary picture, show it at
    (160, -180), classify it with predict() and report the result in the
    title bar. Afterwards the next left press starts a new drawing.
    """
    global isIdle, startDrawing
    if not isIdle:
        # A prediction is still running; ignore clicks until it finishes
        # (onMouseDragged applies the same rule).
        return
    if isLeftMouseButton():
        if x < -100 or x > 100 or y < -100 or y > 100:
            return
        if startDrawing:
            initGraphics()
            startDrawing = False
        pos(x, y)
    if isRightMouseButton():
        isIdle = False
        title("Working. Please wait...")
        try:
            pic = shrinkPicture()
            showPicture(pic, 160, -180)
            perceived = predict(pic)
            title("Perceived digit: " + str(perceived) + " . Draw next!")
        finally:
            # Always re-enable input; otherwise one failed recognition would
            # leave isIdle False and block all further drawing.
            startDrawing = True
            isIdle = True

def onMouseDragged(x, y):
    """Mouse-drag callback: extend the current stroke to (x, y).

    The drag is ignored when it leaves the -100..100 drawing field, when a
    new drawing has not been started yet (startDrawing), or while a
    prediction is running (isIdle is False).
    """
    offField = x < -100 or x > 100 or y < -100 or y > 100
    if not (offField or startDrawing or not isIdle):
        draw(x, y)

def euklidian(pic1, pic2):
    """Return the number of cells in which two binary pictures differ.

    A picture is a list of columns of 0/1 values (e.g. the 20x20 pictures
    built by extractPicture / toPicture). Despite the name this is the
    Hamming distance. For 0/1 values it equals the squared Euclidean
    distance, so ranking neighbours by it is the same as ranking by
    Euclidean distance.

    Works for pictures of any (matching) size; only the overlapping region
    is compared if the sizes differ.
    """
    return sum(
        int(a != b)
        for col1, col2 in zip(pic1, pic2)
        for a, b in zip(col1, col2)
    )

def extractPicture(img, n):
    """Cut tile number n (0..4999) out of the sample sheet img.

    Tiles are 20x20 pixels, laid out 100 per row. The result is a 20x20
    list where pic[x][y] is 1 if the pixel's integer-averaged RGB brightness
    exceeds 127, else 0.
    """
    left = (n % 100) * 20
    top = (n // 100) * 20
    pic = [[0] * 20 for _ in range(20)]
    for col in range(20):
        for row in range(20):
            px = left + col
            py = top + row
            if isTigerJython:
                color = img.getPixelColor(px, py)
                total = color.getRed() + color.getGreen() + color.getBlue()
            else:
                color = GPanel.getPixelColor(img, px, py)
                total = color[0] + color[1] + color[2]
            pic[col][row] = 1 if total // 3 > 127 else 0
    return pic

def getDigit(n):
    """Map a sample index to its digit label.

    The sample sheet is treated as 500 consecutive samples per digit:
    indices 0..499 are 0, 500..999 are 1, and so on.
    """
    digit, _ = divmod(n, 500)
    return digit

def loadData(filename):
    """Load the sample sheet and split it into its 5000 binary pictures.

    The returned list is in index order, so getDigit(i) is the label of
    element i.
    """
    sheet = getImage(filename)
    return [extractPicture(sheet, n) for n in range(5000)]

def predict(pic0):
    """Classify pic0 with a 1-nearest-neighbour search over `samples`.

    Returns getDigit() of the sample with the smallest euklidian() distance
    to pic0. If several samples tie, the lowest index wins. min() returns
    the first minimal item, which matches the original stable sort.
    """
    nearest = min(
        range(len(samples)),
        key=lambda n: euklidian(pic0, samples[n]),
    )
    return getDigit(nearest)

def shrinkPicture():
    """Turn the user's drawing into a binary picture comparable to samples.

    Grabs the panel image, crops the pixel region 100..300 (the 400x400
    panel maps window coordinates -100..100 there; crop arguments presumably
    are corner coordinates - verify against the gpanel docs). It then scales
    the crop by 1/10 and thresholds it with toPicture().
    """
    factor = 1 / 10.0
    if isTigerJython:
        field = GBitmap.crop(getBitmap(), 100, 100, 300, 300)
        small = GBitmap.scale(field, factor, 0)
    else:
        field = GPanel.crop(getFullImage(), 100, 100, 300, 300)
        small = GPanel.scale(field, factor)
    return toPicture(small)

def toPicture(img):
    """Convert the top-left 20x20 pixels of img into a binary picture.

    pic[x][y] is 1 if the pixel's integer-averaged RGB brightness exceeds
    127, else 0. This is exactly extractPicture() for tile 0 (both offsets
    are zero), so delegate to it rather than duplicating the pixel loop.
    """
    return extractPicture(img, 0)

def showPicture(pic, xpos, ypos):
    """Draw a 20x20 binary picture at (xpos, ypos).

    Uses the TigerJython bitmap renderer or the Qt renderer, depending on
    the runtime.
    """
    render = showPictureTJ if isTigerJython else showPicturePy
    render(pic, xpos, ypos)

def showPicturePy(pic, xpos, ypos):
    """Qt renderer: paint pic into a 20x20 image and draw it at (xpos, ypos).

    Cells equal to 1 become white pixels; every other value becomes black.
    """
    white = QColor(255, 255, 255).rgb()
    black = QColor(0, 0, 0).rgb()
    canvas = QPixmap(20, 20).toImage()
    for col in range(20):
        for row in range(20):
            canvas.setPixel(col, row, white if pic[col][row] == 1 else black)
    image(canvas, xpos, ypos)

def showPictureTJ(pic, xpos, ypos):
    """TigerJython renderer: paint pic into a 20x20 GBitmap and draw it.

    Cells equal to 1 become white pixels; every other value becomes black.
    """
    bitmap = GBitmap(20, 20)
    for col in range(20):
        for row in range(20):
            shade = "white" if pic[col][row] == 1 else "black"
            bitmap.setPixelColorStr(col, row, shade)
    image(bitmap, xpos, ypos)

def initGraphics():
    """Reset the panel for a new drawing.

    Fills the -110..110 square black as the drawing field and labels the
    example row. It then shows one randomly chosen sample per digit along
    the bottom, and leaves the pen set to white with width 20 for the
    user's strokes.
    """
    lineWidth(1)
    bgColor("white")
    setColor("black")
    fillRectangle(-110, -110, 110, 110)
    text(-180, -150, "Examples:")
    # Pen for the user's strokes (the former duplicate setColor call was
    # redundant and has been removed).
    setColor("white")
    lineWidth(20)
    for i in range(10):
        # Indices 500*i .. 500*i+499 are labelled digit i by getDigit().
        n = random.randint(500 * i, 500 * (i + 1) - 1)
        showPicture(samples[n], -180 + i * 22, -180)
    title("Draw a digit and click right!")

# The mouse callbacks registered by makeGPanel read these globals. Define
# them before the panel exists, because events may arrive while the data is
# still loading. Input counts as busy (isIdle False) until loading finishes.
isIdle = False
startDrawing = True
makeGPanel(Size(400, 400),
    mousePressed = onMousePressed,
    mouseDragged = onMouseDragged)
window(-200, 200, -200, 200)
title("Loading data. Please wait...")
samples = loadData("digits.png")
initGraphics()
isIdle = True
keep()
