"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import sys
import six
import cv2
import numpy as np


class GenTableMask(object):
    """Generate a text-region mask from the cells of a table image.

    For every cell that carries a ``bbox``, the cell crop is split into
    text-line boxes with projection profiles, each box is shrunk a little,
    and the resulting rectangles are painted onto a blank mask.
    """

    def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
        """
        Args:
            shrink_h_max: upper bound, in pixels, of the vertical shrink
                applied to the top and bottom of every box.
            shrink_w_max: upper bound, in pixels, of the horizontal shrink
                applied to the left and right of every box.
            mask_type: 1 produces a single-channel mask stored under
                ``data['mask_img']``; any other value produces a
                three-channel mask that replaces ``data['image']``.
            **kwargs: accepted and ignored.
        """
        # Honour the configured limits; they used to be overwritten with a
        # hard-coded 5 regardless of what the caller passed.
        self.shrink_h_max = shrink_h_max
        self.shrink_w_max = shrink_w_max
        self.mask_type = mask_type

    def projection(self, erosion, h, w, spilt_threshold=0):
        """Split a binary image into runs of rows that contain foreground.

        Args:
            erosion: 2-D array; pixels equal to 255 count as foreground.
            h: number of rows to scan (callers pass ``erosion.shape[0]``).
            w: number of columns to scan (callers pass ``erosion.shape[1]``).
            spilt_threshold: a row is part of a run when its foreground
                count is strictly greater than this value.

        Returns:
            A ``(box_list, projection_map)`` tuple. ``box_list`` holds one
            ``(start, end)`` tuple of row indices per run. ``projection_map``
            has the shape and dtype of ``erosion``; it is filled with ones,
            and in every row the first ``count`` entries are zero, where
            ``count`` is that row's foreground pixel count.
        """
        # Horizontal projection: number of foreground pixels in every row.
        row_counts = (erosion[:h, :w] == 255).sum(axis=1)

        # Walk the profile and collect the cut points.
        start_idx = 0  # row index where the current text run began
        in_text = False  # whether the scan is currently inside a text run
        box_list = []
        for idx, count in enumerate(row_counts.tolist()):
            if not in_text and count > spilt_threshold:
                # Entering a text run.
                in_text = True
                start_idx = idx
            elif in_text and count <= spilt_threshold:
                # Entering a blank area: the run ends here.
                in_text = False
                # Runs that are at most two rows tall are dropped.
                if idx - start_idx <= 2:
                    continue
                box_list.append((start_idx, idx + 1))

        if in_text:
            # The last run reaches the final scanned row.
            box_list.append((start_idx, h - 1))

        # Draw the projection histogram.
        projection_map = np.ones_like(erosion)
        region = projection_map[:h, :w]
        filled = np.arange(region.shape[1])[None, :] < row_counts[:, None]
        region[filled] = 0
        return box_list, projection_map

    def projection_cx(self, box_img):
        """Split one cell crop into text-line boxes.

        Args:
            box_img: three-channel cell crop; it is converted to grayscale
                with ``cv2.COLOR_BGR2GRAY``.

        Returns:
            A list of ``[w_start, h_start, w_end, h_end]`` boxes in crop
            coordinates. When at most one text line is detected the list
            holds the single box ``[0, 0, w, h]`` covering the whole crop.
        """
        box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
        h, w = box_gray_img.shape
        # Binarize the grayscale crop: pixels with value <= 200 become 255
        # (foreground), brighter pixels become 0.
        _, thresh1 = cv2.threshold(box_gray_img, 200, 255,
                                   cv2.THRESH_BINARY_INV)
        # Vertical erosion, applied only when the crop is wider than tall.
        if h < w:
            erode = cv2.erode(thresh1, np.ones((2, 1), np.uint8), iterations=1)
        else:
            erode = thresh1
        # Horizontal dilation.
        erosion = cv2.dilate(erode, np.ones((1, 5), np.uint8), iterations=1)

        # Horizontal projection gives the candidate text lines.
        box_list, _ = self.projection(erosion, h, w)
        if len(box_list) <= 1:
            return [[0, 0, w, h]]

        split_bbox_list = []
        last_idx = len(box_list) - 1
        for i, (h_start, h_end) in enumerate(box_list):
            # The first line is extended to the top of the crop and the
            # last line to the bottom of the crop.
            if i == 0:
                h_start = 0
            if i == last_idx:
                h_end = h
            word_img = erosion[h_start:h_end + 1, :]
            word_h, word_w = word_img.shape
            # Projecting the transposed strip yields the column runs.
            w_split_list, _ = self.projection(word_img.T, word_w, word_h)
            w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
            # Pad the line by one row on each side. The bottom may end up
            # past the crop; the slices taken from these boxes are clipped
            # by numpy.
            if h_start > 0:
                h_start -= 1
            h_end += 1
            split_bbox_list.append([w_start, h_start, w_end, h_end])
        return split_bbox_list

    def shrink_bbox(self, bbox):
        """Shrink a box on every side.

        Each side moves inwards by 10% of the box extent along that axis,
        but by at least 1 and at most ``shrink_h_max`` / ``shrink_w_max``.
        If shrinking would collapse an axis, that axis keeps its original
        coordinates.

        Args:
            bbox: ``(left, top, right, bottom)``.

        Returns:
            ``[left, top, right, bottom]`` of the shrunk box.
        """
        left, top, right, bottom = bbox
        sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
        sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
        left_new = left + sh_w
        right_new = right - sh_w
        top_new = top + sh_h
        bottom_new = bottom - sh_h
        if left_new >= right_new:
            left_new = left
            right_new = right
        if top_new >= bottom_new:
            top_new = top
            bottom_new = bottom
        return [left_new, top_new, right_new, bottom_new]

    def __call__(self, data):
        """Paint the text regions of all cells onto a mask.

        Args:
            data: dict with ``'image'`` (three-channel array) and
                ``'cells'`` (sequence of dicts; only entries that contain
                a ``'bbox'`` of ``(left, top, right, bottom)`` are used).

        Returns:
            The same ``data`` dict. Once at least one cell has been
            processed, the mask is stored under ``'mask_img'`` when
            ``mask_type == 1``, otherwise it replaces ``'image'``.
        """
        img = data['image']
        cells = data['cells']
        height, width = img.shape[0:2]
        if self.mask_type == 1:
            mask_img = np.zeros((height, width), dtype=np.float32)
        else:
            mask_img = np.zeros((height, width, 3), dtype=np.float32)
        for cell in cells:
            if "bbox" not in cell:
                continue
            cell_left, cell_top, cell_right, cell_bottom = cell['bbox']
            # Crops always come from the original image, even after
            # data['image'] has been replaced by the mask below.
            box_img = img[cell_top:cell_bottom, cell_left:cell_right, :].copy()
            for split_bbox in self.projection_cx(box_img):
                # Translate from crop coordinates to image coordinates,
                # then shrink.
                left, top, right, bottom = self.shrink_bbox([
                    split_bbox[0] + cell_left, split_bbox[1] + cell_top,
                    split_bbox[2] + cell_left, split_bbox[3] + cell_top
                ])
                if self.mask_type == 1:
                    mask_img[top:bottom, left:right] = 1.0
                    data['mask_img'] = mask_img
                else:
                    mask_img[top:bottom, left:right, :] = (255, 255, 255)
                    data['image'] = mask_img
        return data


class ResizeTableImage(object):
    """Rescale an image so that its longer side becomes ``max_len``.

    The aspect ratio is kept: both sides are multiplied by the same factor
    and truncated to integers.
    """

    def __init__(self, max_len, resize_bboxes=False, infer_mode=False,
                 **kwargs):
        """
        Args:
            max_len: target length of the longer image side.
            resize_bboxes: when true (and not in inference mode),
                ``data['bboxes']`` is multiplied by the scale factor too.
            infer_mode: when true, ``data['bboxes']`` is never touched.
            **kwargs: accepted and ignored.
        """
        super(ResizeTableImage, self).__init__()
        self.infer_mode = infer_mode
        self.resize_bboxes = resize_bboxes
        self.max_len = max_len

    def __call__(self, data):
        """Resize ``data['image']`` and record the resize metadata.

        Writes ``'image'`` (resized), ``'src_img'`` (the input image),
        ``'shape'`` (``[height, width, ratio, ratio]`` of the input) and
        ``'max_len'`` into ``data`` and returns it.
        """
        source = data['image']
        src_h, src_w = source.shape[:2]
        scale = self.max_len / float(max(src_h, src_w))
        # cv2.resize takes the target size as (width, height).
        target_size = (int(src_w * scale), int(src_h * scale))
        resized = cv2.resize(source, target_size)

        rescale_boxes = self.resize_bboxes and not self.infer_mode
        if rescale_boxes:
            data['bboxes'] = data['bboxes'] * scale

        data['image'] = resized
        data['src_img'] = source
        data['shape'] = np.array([src_h, src_w, scale, scale])
        data['max_len'] = self.max_len
        return data


class PaddingTableImage(object):
    """Place an image in the top-left corner of a zero-filled canvas."""

    def __init__(self, size, **kwargs):
        """
        Args:
            size: ``(pad_h, pad_w)`` of the output canvas.
            **kwargs: accepted and ignored.
        """
        super(PaddingTableImage, self).__init__()
        self.size = size

    def __call__(self, data):
        """Pad ``data['image']`` and append the canvas size to ``data['shape']``.

        The canvas is a float32 array of shape ``(pad_h, pad_w, 3)``; the
        input image is written at the origin and the remainder stays zero.
        """
        source = data['image']
        pad_h, pad_w = self.size
        canvas = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
        src_h, src_w = source.shape[:2]
        canvas[:src_h, :src_w, :] = source
        data['image'] = canvas

        # Extend the recorded shape with the padded height and width.
        padded_shape = data['shape'].tolist()
        padded_shape.extend((pad_h, pad_w))
        data['shape'] = np.array(padded_shape)
        return data
