from keras.preprocessing import image
from keras.preprocessing.image import array_to_img, img_to_array
from keras.models import load_model
import numpy as np
import cv2

# Load the trained CNN once, at import time, so every call to main() reuses
# the same model object instead of reloading the file.
# compile=False: the model is not compiled after loading; the code here only
# calls predict(), which does not need a compiled model.
saved_model = load_model("data/cnn/model.h5", compile=False)
# Build the predict function eagerly. NOTE(review): presumably done so that
# predict() can later be called from other threads (a common workaround when
# serving from a web framework) — confirm it is still needed with the
# installed Keras version.
saved_model.make_predict_function()


def crop_id_card(image):
    """Crop the largest four-sided contour (taken to be the ID card) from an image.

    Args:
        image: Image array as returned by ``cv2.imread`` (BGR or BGRA channel
            order), a 2-D grayscale array, or ``None``.

    Returns:
        The bounding-box crop of the card as a PIL image in RGB order, or
        ``None`` if ``image`` is ``None`` or no four-vertex contour is found.
    """
    # cv2.imread returns None instead of raising on unreadable files, so a
    # caller may pass its result straight through.
    if image is None:
        print("Error: Unable to load image. Please check the file path.")
        return None

    # Accept both colour (BGR/BGRA) and already-grayscale input.
    is_gray = image.ndim == 2
    gray = image if is_gray else cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Blur first so Canny responds to real edges rather than pixel noise.
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    edges = cv2.Canny(blurred, 50, 150)

    # findContours returns (contours, hierarchy) in OpenCV 4.x but
    # (image, contours, hierarchy) in 3.x; [-2] selects the contours in both.
    contours = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    # Walk contours from largest to smallest area and keep the first whose
    # polygon approximation has exactly four vertices.
    id_card_contour = None
    for contour in sorted(contours, key=cv2.contourArea, reverse=True):
        epsilon = 0.02 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        if len(approx) == 4:
            id_card_contour = approx
            break

    if id_card_contour is None:
        print("ID card not found in the image.")
        return None

    # Crop to the axis-aligned bounding box of the quadrilateral.
    x, y, w, h = cv2.boundingRect(id_card_contour)
    cropped_id_card = image[y:y + h, x:x + w]

    # Convert to RGB and wrap as a PIL image via keras' array_to_img.
    # scale=False keeps the original 0-255 pixel values (the default would
    # stretch them to the full range); channels_last matches OpenCV's layout.
    to_rgb = cv2.COLOR_GRAY2RGB if is_gray else cv2.COLOR_BGR2RGB
    rgb = cv2.cvtColor(cropped_id_card, to_rgb)
    return array_to_img(rgb, data_format="channels_last", scale=False)


def main(image):
    """Classify an image with the module-level CNN and report whether it is a KTP.

    Args:
        image: A PIL-style image object, i.e. anything with a
            ``resize((width, height))`` method whose result ``np.array`` can
            convert to a pixel array. It is classified as-is;
            ``crop_id_card`` is not applied here.

    Returns:
        ``True`` when the model's first output for this image equals 0 (the
        value this code treats as "KTP detected"), otherwise ``False``.
    """
    print(image)

    # Resize to 150x150 and add a leading batch axis of size 1 for predict().
    pixels = np.expand_dims(np.array(image.resize((150, 150))), axis=0)
    # NOTE(review): pixels are passed unscaled (0-255). If the model was
    # trained on rescaled input (e.g. 1/255), confirm that matches here.
    prediction = saved_model.predict(pixels)
    print('prediction:')
    print(prediction)

    # 0 means KTP is detected.
    # NOTE(review): exact equality only matches an output of exactly 0; if the
    # model ends in a sigmoid, a threshold (e.g. < 0.5) may be intended — confirm.
    # bool() converts numpy's np.bool_ to a plain Python bool so callers can
    # JSON-serialize it or compare with `is True`.
    return bool(prediction[0][0] == 0)

if __name__ == '__main__':
    import sys

    # Command-line entry point: <script> IMAGE_PATH
    if len(sys.argv) != 2:
        sys.exit("usage: %s IMAGE_PATH" % sys.argv[0])

    # main() calls .resize((w, h)) on its argument, so the path must be loaded
    # into an image first. At module scope `image` is the
    # keras.preprocessing.image module imported at the top of the file
    # (main's parameter of the same name only shadows it inside main).
    print("KTP detected:", main(image.load_img(sys.argv[1])))