import torch
import os
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision.utils as vutils

from utils.files_lib import get_full_path
from cartoon_api.network.Transformer import Transformer


class GanCartoonReadyModelHelper:
    """Apply a pretrained ``Transformer`` generator to a single image.

    The generator weights are read from
    ``<pretrained_model dir>/<style>_net_G_float.pth``. The directory is
    resolved via ``get_full_path("cartoon_api", "pretrained_model")``, and
    the weights are loaded when the helper is constructed.
    """

    # Extensions accepted by ``is_valid_file`` (compared case-insensitively).
    valid_ext = ['.jpg', '.png', '.jpeg']

    def __init__(self, input_file, style, gpu=-1, output_dir="test_output"):
        """Load the generator weights for ``style``.

        Args:
            input_file: Path of the image to process.
            style: Style name; selects ``<style>_net_G_float.pth``.
            gpu: -1 runs on CPU. Any value > -1 moves the model to CUDA via
                ``.cuda()``. This uses the default CUDA device; the index
                itself is not used to pick a device.
            output_dir: Directory that ``applyFilter`` writes results into.
        """
        self.inputFile = input_file
        self.output_dir = output_dir
        # Length in pixels of the longer image side after resizing.
        self.load_size = 450
        self.model_path = get_full_path("cartoon_api", "pretrained_model")
        self.style = style
        self.gpu = gpu

        weights_file = os.path.join(self.model_path, self.style + '_net_G_float.pth')
        self.model = Transformer()
        # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only
        # hosts; the model is moved to CUDA below when requested.
        self.model.load_state_dict(torch.load(weights_file, map_location='cpu'))
        self.model.eval()

        if self.gpu > -1:
            print('GPU mode')
            self.model.cuda()
        else:
            print('CPU mode')
            self.model.float()

    def is_valid_file(self):
        """Return True if ``inputFile`` ends in one of ``valid_ext`` (any case)."""
        ext = os.path.splitext(self.inputFile)[1].lower()
        return ext in self.valid_ext

    def applyFilter(self) -> str:
        """Run the generator on ``inputFile`` and save the result as a PNG.

        Returns:
            Path of the written file: ``<output_dir>/<input stem>_<style>.png``.
        """
        os.makedirs(self.output_dir, exist_ok=True)

        # The context manager guarantees the source file handle is closed.
        with Image.open(self.inputFile) as source:
            image = source.convert("RGB")
        image = image.resize(self._fit_size(*image.size), Image.BICUBIC)
        batch = self._to_model_input(image)

        # Inference only: no_grad avoids building an autograd graph.
        # (Variable(..., volatile=True) has had no effect since PyTorch 0.4.)
        with torch.no_grad():
            output = self.model(batch)

        # Drop the batch dim and swap BGR -> RGB. Then rescale from the
        # generator's presumed (-1, 1) output range to (0, 1) for saving.
        result = output[0][[2, 1, 0], :, :].cpu().float() * 0.5 + 0.5

        output_file = os.path.join(
            self.output_dir, self.get_output_path(newName=self.style + '.png')
        )
        vutils.save_image(result, output_file)
        return output_file

    def _fit_size(self, width, height):
        """Return a (width, height) whose longer side is ``load_size``, keeping the aspect ratio."""
        ratio = width * 1.0 / height
        if ratio > 1:
            new_width = self.load_size
            new_height = int(new_width * 1.0 / ratio)
        else:
            new_height = self.load_size
            new_width = int(new_height * ratio)
        # Guard against a 0-pixel side for extremely elongated images.
        return max(new_width, 1), max(new_height, 1)

    def _to_model_input(self, image):
        """Convert an RGB PIL image to a 1xCxHxW BGR tensor in [-1, 1] on the model's device."""
        # The channel swap implies the pretrained generator expects BGR input.
        # Fancy indexing returns a contiguous copy. A ``::-1`` slice would have
        # negative strides, which torch cannot wrap.
        pixels = np.asarray(image)[:, :, [2, 1, 0]]
        tensor = transforms.ToTensor()(pixels).unsqueeze(0)
        # ToTensor yields [0, 1]; map to [-1, 1].
        tensor = -1 + 2 * tensor
        return tensor.cuda() if self.gpu > -1 else tensor.float()

    def get_output_path(self, newName='new_ot'):
        """Return ``<input stem>_<newName>``, a bare file name (not joined with a directory).

        The extension is removed with ``os.path.splitext``, so inputs of any
        extension length (e.g. ``.jpeg``) are handled correctly.
        """
        stem = os.path.splitext(os.path.basename(self.inputFile))[0]
        return '{}_{}'.format(stem, newName)
