from pathlib import Path
import json

import numpy as np
from PIL import Image


# Parent of the directory containing this script; every input and output of
# the demo lives under it.
DOCS_DIR = Path(__file__).resolve().parents[1]

# Inputs: the source photo, the patch image returned by the API, and the
# API's JSON result (which carries the patch placement matrix).
PHOTO_PATH = DOCS_DIR.joinpath("GlassesGlareAndLiftingExample.jpg")
PATCH_PATH = DOCS_DIR.joinpath(
    "glasses_lifting_patch",
    "index=3&blendmode=linearLight&name=Glasses Anti Glare&layoutmode=patch.png",
)
JSON_PATH = DOCS_DIR.joinpath("glasses_lifting_patch", "result.json")

# Outputs: one composite per blend mode.
NORMAL_OUTPUT_PATH = DOCS_DIR.joinpath("GlassesPatchAppliedNormal.jpg")
LINEAR_LIGHT_OUTPUT_PATH = DOCS_DIR.joinpath("GlassesPatchAppliedLinearLight.jpg")


def load_glasses_matrix(result_json_path):
    """Return the patch matrix recorded by the "Glasses Anti Glare" plugin.

    The first plugin entry whose ``pluginName`` is ``"Glasses Anti Glare"`` is
    used, and the matrix is taken from its first face.

    Args:
        result_json_path: ``Path`` or path string of the API result JSON.

    Returns:
        The ``glassesAntiGlarePatchMatrix`` value exactly as stored in the
        JSON. Presumably a row-major affine matrix (at least 2x3) mapping
        original-photo coordinates to patch coordinates — confirm against the
        API documentation.

    Raises:
        RuntimeError: if no such plugin exists, if it lists no faces, or if
            its first face carries no ``glassesAntiGlarePatchMatrix``.
    """
    # Path() accepts both str and Path, so callers may pass either.
    with Path(result_json_path).open("r", encoding="utf-8") as file:
        data = json.load(file)

    for plugin in data.get("plugins", []):
        if plugin.get("pluginName") != "Glasses Anti Glare":
            continue

        faces = plugin.get("faces") or []
        if not faces:
            # Previously surfaced as a bare IndexError/KeyError.
            raise RuntimeError("Glasses Anti Glare metadata contains no faces")

        matrix = faces[0].get("glassesAntiGlarePatchMatrix")
        if matrix is None:
            raise RuntimeError(
                "Glasses Anti Glare metadata has no glassesAntiGlarePatchMatrix"
            )
        return matrix

    raise RuntimeError("Glasses Anti Glare metadata was not found")


def linear_light_blend(base_rgb, overlay_rgba):
    """Composite an RGBA overlay onto an RGB base using the Linear Light mode.

    On 8-bit channels Linear Light is ``base + 2 * overlay - 255``, clipped to
    ``[0, 255]``. The overlay's alpha then mixes that result with the untouched
    base, so fully transparent overlay pixels leave the base unchanged.

    Args:
        base_rgb: RGB image (or array-like accepted by ``np.asarray``) of shape
            ``(H, W, 3)``.
        overlay_rgba: RGBA image / array of shape ``(H, W, 4)`` with the same
            height and width as ``base_rgb``.

    Returns:
        A new ``PIL.Image.Image`` in mode ``"RGB"``.

    Raises:
        ValueError: if the base and overlay have different height/width.
    """
    base = np.asarray(base_rgb, dtype=np.float32)
    overlay = np.asarray(overlay_rgba, dtype=np.float32)
    if base.shape[:2] != overlay.shape[:2]:
        raise ValueError(
            f"base (height, width) {base.shape[:2]} does not match "
            f"overlay (height, width) {overlay.shape[:2]}"
        )

    overlay_rgb = overlay[..., :3]
    # Keep a trailing axis of length 1 so alpha broadcasts across R, G and B.
    overlay_alpha = overlay[..., 3:4] / 255.0
    linear_light = np.clip(base + 2.0 * overlay_rgb - 255.0, 0.0, 255.0)
    blended = base * (1.0 - overlay_alpha) + linear_light * overlay_alpha

    # Round to nearest before casting: astype(uint8) truncates the fractional
    # part, which would bias partially transparent pixels down by up to one
    # 8-bit level.
    return Image.fromarray(np.clip(np.rint(blended), 0, 255).astype(np.uint8), "RGB")


def normal_blend(base_rgb, overlay_rgba):
    """Alpha-composite ``overlay_rgba`` over ``base_rgb`` (the "Normal" mode).

    The base is promoted to RGBA because ``Image.alpha_composite`` requires
    both inputs in that mode (and of equal size). The alpha channel is then
    dropped so the result is an RGB image.
    """
    return Image.alpha_composite(base_rgb.convert("RGBA"), overlay_rgba).convert("RGB")


def _load_image(path, mode):
    """Open the image at ``path``, return a converted copy and close the file."""
    # Image.open holds the file handle open lazily; convert() returns a new,
    # loaded image, so the source handle can be released immediately.
    with Image.open(path) as image:
        return image.convert(mode)


def _affine_coefficients(matrix):
    """Flatten the first two matrix rows into PIL's (a, b, c, d, e, f) tuple."""
    if len(matrix) < 2 or any(len(row) < 3 for row in matrix[:2]):
        raise ValueError(f"expected at least a 2x3 affine matrix, got {matrix!r}")
    return tuple(matrix[0][:3]) + tuple(matrix[1][:3])


def main():
    """Warp the API patch onto the photo and save both blend variants.

    Reads PHOTO_PATH, PATCH_PATH and JSON_PATH, then writes JPEGs to
    NORMAL_OUTPUT_PATH (plain alpha compositing) and LINEAR_LIGHT_OUTPUT_PATH
    (Linear Light blending), printing each saved path.
    """
    photo = _load_image(PHOTO_PATH, "RGB")
    patch = _load_image(PATCH_PATH, "RGBA")
    matrix = load_glasses_matrix(JSON_PATH)

    # PIL affine transform coefficients map output pixels to input pixels.
    # The API matrix already maps original-photo coordinates to patch
    # coordinates, so it is passed through without inversion.
    full_size_overlay = patch.transform(
        photo.size,
        Image.Transform.AFFINE,
        _affine_coefficients(matrix),
        resample=Image.Resampling.BICUBIC,
        # Photo pixels that map outside the patch become fully transparent,
        # so both blends leave them unchanged.
        fillcolor=(0, 0, 0, 0),
    )

    outputs = (
        (NORMAL_OUTPUT_PATH, normal_blend(photo, full_size_overlay)),
        (LINEAR_LIGHT_OUTPUT_PATH, linear_light_blend(photo, full_size_overlay)),
    )
    for output_path, result in outputs:
        result.save(output_path, quality=95)
        print(f"Saved {output_path}")


if __name__ == "__main__":
    main()
