import os
from collections import Counter

from bs4 import BeautifulSoup
import numpy as np
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'
import tensorflow as tf
import cv2

# Absolute path of the folder that holds this script.
current_dir = os.path.split(os.path.abspath(__file__))[0]

# The assets folder sits one level above this script's folder; the font
# classification model file is stored inside it.
assets_dir = os.path.join(current_dir, '..', 'assets')
font_model_path = os.path.join(assets_dir, 'font_model.h5')

class FontDetector:
    """Detect the font class (serif / sans serif) of text from a page image.

    Each paragraph of an hOCR document is classified by running the font
    model on crops of some of its words; the resulting class is written to
    every non-empty word of that paragraph as a ``font_class`` attribute.
    """

    def __init__(self):
        # Loads the model from the module-level ``font_model_path``.
        self.font_model = tf.keras.models.load_model(font_model_path)

    def _get_prediction(self, img):
        """Classify one pre-processed word image.

        Args:
            img: Word image as produced by ``_get_img``.

        Returns:
            ``'sans'`` when the model output is below 0.5, else ``'serif'``.
        """
        batch = np.array([img])  # wrap the single image into a batch of one
        prediction = self.font_model.predict(batch, verbose=0)[0][0]
        return 'sans' if prediction < 0.5 else 'serif'

    @staticmethod
    def _parse_bbox(title):
        """Extract the four bbox integers from an hOCR ``title`` attribute.

        The ``title`` is a ';'-separated property list such as
        ``"bbox 10 20 110 60; x_wconf 95"``. The ``bbox`` property is looked
        up by name, so it does not have to be the first property.

        Args:
            title: The ``title`` attribute value, or ``None`` if absent.

        Returns:
            ``[x0, y0, x1, y1]`` as ints, or ``None`` when there is no
            well-formed ``bbox`` property.
        """
        for prop in (title or '').split(';'):
            fields = prop.split()
            if len(fields) == 5 and fields[0] == 'bbox':
                try:
                    return [int(value) for value in fields[1:]]
                except ValueError:
                    return None
        return None

    def _get_img(self, left, top, width, height, img):
        """Crop one word out of the page image and prepare it for the model.

        NOTE: despite their names, ``width`` and ``height`` are used as the
        right and bottom edge coordinates of the box (the crop is
        ``img[top:height, left:width]``). The names are kept so existing
        callers keep working.

        Args:
            left: Left edge (x) of the word box.
            top: Top edge (y) of the word box.
            width: Right edge (x) of the word box.
            height: Bottom edge (y) of the word box.
            img: Full page image (rows, columns, channels).

        Returns:
            A grayscale crop resized to 240x120 (width x height) with a
            trailing channel axis, or ``None`` if the box selects no pixels.
        """
        sub_img = img[top:height, left:width, :]
        if sub_img.size == 0:
            # Zero-area or inverted box: cv2.resize would raise on it.
            return None
        sub_img = cv2.resize(sub_img, (240, 120))  # dsize is (width, height)
        gray_img = cv2.cvtColor(sub_img, cv2.COLOR_BGR2GRAY)
        return gray_img[..., np.newaxis]

    def _get_paragraph_font(self, paragraph, img):
        """Return the most frequent font class among a paragraph's words.

        The model is run on the first half of the paragraph's non-empty
        words (at least one word, so single-word paragraphs are classified
        too). Words without a usable bounding box are skipped.

        Args:
            paragraph: Paragraph element containing ``ocrx_word`` spans.
            img: Full page image.

        Returns:
            ``'sans'`` or ``'serif'``, or ``None`` when no word could be
            classified.
        """
        boxes = []
        for word in paragraph.find_all('span', {'class': 'ocrx_word'}):
            if not word.get_text().strip():
                continue
            bbox = self._parse_bbox(word.get('title'))
            if bbox is not None:
                boxes.append(bbox)
        if not boxes:
            return None

        sample_size = max(1, len(boxes) // 2)
        predictions = []
        # Only the sampled words are cropped; the rest are never classified.
        for left, top, right, bottom in boxes[:sample_size]:
            word_img = self._get_img(left, top, right, bottom, img)
            if word_img is not None:
                predictions.append(self._get_prediction(word_img))
        if not predictions:
            return None
        return Counter(predictions).most_common(1)[0][0]

    def get_updated_hocr_data(self, hocr_data, img_path):
        """Annotate the words of an hOCR document with their font class.

        Args:
            hocr_data: hOCR markup to parse.
            img_path: Path of the page image the hOCR was produced from.

        Returns:
            The prettified markup, in which every non-empty ``ocrx_word``
            span of a classified paragraph carries a ``font_class``
            attribute.

        Raises:
            ValueError: If the image at ``img_path`` cannot be read.
        """
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None, not by raising.
            raise ValueError(f"Could not read image: {img_path!r}")
        soup = BeautifulSoup(hocr_data, 'xml')

        for paragraph in soup.find_all('p', {'class': 'ocr_par'}):
            font_class = self._get_paragraph_font(paragraph, img)
            if not font_class:
                continue
            for word in paragraph.find_all('span', {'class': 'ocrx_word'}):
                if word.get_text().strip():
                    # Setting the attribute mutates the tag inside the tree,
                    # so no lookup or replacement of the tag is needed.
                    word['font_class'] = font_class

        return soup.prettify()

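# Example usage (get_updated_hocr_data takes the hOCR content, not its path):
# if(__name__ == '__main__'):
#     font_detector = FontDetector()
#     hocr_file = '../../Test Files/Testing/X_bl.hocr'
#     img_file = '../../Test Files/Testing/X_bl.png'
#     with open(hocr_file, encoding='utf-8') as f:
#         updated_hocr = font_detector.get_updated_hocr_data(f.read(), img_file)
#     print(updated_hocr)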