from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import os.path
from datetime import datetime
from PIL import Image
import numpy as np
from io import BytesIO
import requests
import base64

import tensorflow as tf
from tensorflow.python.platform import gfile
import captcha_model as captcha

import config

# Model geometry and label alphabet, mirrored from the project ``config`` module.
IMAGE_WIDTH, IMAGE_HEIGHT = config.IMAGE_WIDTH, config.IMAGE_HEIGHT

CHAR_SETS, CLASSES_NUM, CHARS_NUM = (
    config.CHAR_SETS,
    config.CLASSES_NUM,
    config.CHARS_NUM,
)

# Base64-encoded JPEG used as the fallback captcha image: input_data decodes
# it whenever no image argument is given on the command line.
SAMPLE_IMAGE =  "/9j/4AAQSkZJRgABAQEAYABgAAD//gA7Q1JFQVRPUjogZ2QtanBlZyB2MS4wICh1c2luZyBJSkcgSlBFRyB2NjIpLCBxdWFsaXR5ID0gMTMK/9sAQwA9Ki42LiY9NjI2RUE9SVyaZFxUVFy8ho5vmt/E6ubbxNfT9v////b////T1////////////+7//////////////9sAQwFBRUVcUVy0ZGS0//3X/f///////////////////////////////////////////////////////////////////8AAEQgAFAAtAwEiAAIRAQMRAf/EAB8AAAEFAQEBAQEBAAAAAAAAAAABAgMEBQYHCAkKC//EALUQAAIBAwMCBAMFBQQEAAABfQECAwAEEQUSITFBBhNRYQcicRQygZGhCCNCscEVUtHwJDNicoIJChYXGBkaJSYnKCkqNDU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6g4SFhoeIiYqSk5SVlpeYmZqio6Slpqeoqaqys7S1tre4ubrCw8TFxsfIycrS09TV1tfY2drh4uPk5ebn6Onq8fLz9PX29/j5+v/EAB8BAAMBAQEBAQEBAQEAAAAAAAABAgMEBQYHCAkKC//EALURAAIBAgQEAwQHBQQEAAECdwABAgMRBAUhMQYSQVEHYXETIjKBCBRCkaGxwQkjM1LwFWJy0QoWJDThJfEXGBkaJicoKSo1Njc4OTpDREVGR0hJSlNUVVZXWFlaY2RlZmdoaWpzdHV2d3h5eoKDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uLj5OXm5+jp6vLz9PX29/j5+v/aAAwDAQACEQMRAD8A0aan8Q9D/wDXpJApGXPyjkjsajRfvDG0ODgen+c0CY5pgFcr1AJGehxT1ARAueFGOagbcfKjKkbWHPbgZ/pUzoXxzwOoI60MLuwqOHBK9AcU6o4QRvz/AHj2qSgFsNZFfG7PHTmkKhWXGeT3OexoooBiN/r0H+yT/KpKKKBhRRRQB//Z"

# argparse namespace holding --checkpoint_dir and --captcha_dir; stays None
# until the __main__ block at the bottom of the file parses the command line.
FLAGS = None

def one_hot_to_texts(recog_result, char_sets=None):
    """Turn rows of predicted character indices into captcha strings.

    Despite the name, each element of ``recog_result`` is used directly as an
    integer position into the character set (not as a one-hot vector). Each
    row is decoded into one string.

    Args:
        recog_result: 2-D integer array-like, such as the numpy array that
            ``sess.run`` returns in ``run_predict``. It has one row per
            sample and one index per character position.
        char_sets: Sequence that maps an index to its character. Defaults to
            the module-level ``CHAR_SETS``.

    Returns:
        A list with one decoded string per row of ``recog_result``.
    """
    if char_sets is None:
        char_sets = CHAR_SETS
    # Iterate rows directly. The previous ``xrange`` loop raised NameError on
    # Python 3, and reusing ``i`` in the inner comprehension shadowed the
    # row index.
    return [''.join(char_sets[idx] for idx in row) for row in recog_result]


def input_data(image_dir, image_b64=None):
    """Decode one base64-encoded captcha into a normalized, batched model input.

    The image is taken from the first of these that is available:
    ``image_b64``, then ``sys.argv[1``], then the built-in ``SAMPLE_IMAGE``.
    As a side effect, the decoded source image is also written to
    ``image.jpg`` in the current working directory.

    Args:
        image_dir: Unused. It is kept so existing callers keep working;
            ``run_predict`` passes ``FLAGS.captcha_dir``.
        image_b64: Optional base64 string holding the encoded image bytes.
            When omitted, the command-line / sample fallback above applies.

    Returns:
        A tuple ``(images, filenames)``:
        - ``images`` is a float32 array of shape
          ``[1, IMAGE_HEIGHT * IMAGE_WIDTH]``. It holds the grayscale,
          resized pixels scaled from 0..255 into [-0.5, 0.5].
        - ``filenames`` is a one-element list with the path the decoded
          image was saved to.
    """
    batch_size = 1
    saved_path = 'image.jpg'
    if image_b64 is None:
        # NOTE(review): sys.argv[1] is read verbatim. If a flag such as
        # --checkpoint_dir comes first on the command line, that flag would be
        # base64-decoded as the image. Pass the image as the first argument,
        # or via image_b64.
        image_b64 = sys.argv[1] if len(sys.argv) > 1 else SAMPLE_IMAGE

    # The context manager closes the source image even if conversion or
    # saving raises. The resized copy stays usable after the close.
    with Image.open(BytesIO(base64.b64decode(image_b64))) as image:
        image_gray = image.convert('L')
        image_resize = image_gray.resize(size=(IMAGE_WIDTH, IMAGE_HEIGHT))
        image.save(saved_path)

    images = np.zeros([batch_size, IMAGE_HEIGHT * IMAGE_WIDTH], dtype='float32')
    input_img = np.array(image_resize, dtype='float32')
    input_img = np.multiply(input_img.flatten(), 1. / 255) - 0.5
    images[0, :] = input_img
    # Report the file that was actually written. The old code returned
    # 'image.jpeg', a name that never existed on disk.
    return images, [saved_path]


def run_predict():
    """Recognize one captcha with the latest checkpoint and print the result.

    Steps:
    1. Build the inference graph, pinned to the CPU, for the single-image
       batch returned by ``input_data``.
    2. Restore the newest checkpoint found in ``FLAGS.checkpoint_dir``.
    3. Print the decoded text as ``captcha_code_is:<text>!``.

    Raises:
        ValueError: if ``FLAGS.checkpoint_dir`` contains no checkpoint.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        input_images, _ = input_data(FLAGS.captcha_dir)
        images = tf.constant(input_images)
        # keep_prob=1: presumably the dropout keep probability, so dropout is
        # disabled at inference. Confirm against captcha_model.inference.
        logits = captcha.inference(images, keep_prob=1)
        result = captcha.output(logits)
        saver = tf.train.Saver()

        # latest_checkpoint returns None when nothing is found. Fail here with
        # a message that names the directory, instead of inside restore().
        checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if checkpoint_path is None:
            raise ValueError(
                'No checkpoint found in %r' % (FLAGS.checkpoint_dir,))

        # The context manager closes the session even if restore/run raises.
        with tf.Session() as sess:
            saver.restore(sess, checkpoint_path)
            recog_result = sess.run(result)

        text = one_hot_to_texts(recog_result)
        print('captcha_code_is:' + text[0] + '!')

def main(_):
    """Entry point called by ``tf.app.run``; the forwarded argv is ignored."""
    return run_predict()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # (flag, default, help) for each directory option this script owns.
    # Anything argparse does not recognize is forwarded to tf.app.run below.
    for flag_name, default_dir, help_text in (
            ('--checkpoint_dir', './captcha_train',
             'Directory where to restore checkpoint.'),
            ('--captcha_dir', './data/test_data',
             'Directory where to get captcha images.'),
    ):
        parser.add_argument(
            flag_name, type=str, default=default_dir, help=help_text)
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
