"""
Synthetic Blot Generator — GLIGEN + IP-Adapter
================================================
Generates new western blot images where:
  - GLIGEN takes COCO bounding boxes natively (no mask rendering needed)
    and places bands at those exact positions
  - IP-Adapter takes a reference blot image and injects its visual style
    so the output actually looks like a real western blot

Also provides a ControlNet + IP-Adapter fallback for environments
where GLIGEN isn't available.

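Requirements:
    pip install "diffusers>=0.25.0" transformers accelerate torch pillow numpy safetensors
    (matplotlib is additionally needed for the visualization helpers; the
    quotes keep the shell from treating ">" as an output redirect.)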

Usage:
    from synthetic_blot_generator import COCO_DATASET, generate_synthetic_dataset
    synth = generate_synthetic_dataset(
        coco_dict=COCO_DATASET,
        reference_image_path="IMG_2270.jpeg",
        n_samples=10,
    )
"""

import json
import copy
import random
from pathlib import Path
from typing import Optional

import numpy as np
from PIL import Image, ImageDraw, ImageFilter
import torch


# ═══════════════════════════════════════════════════════════════════════════════
# 1. COCO ANNOTATIONS
# ═══════════════════════════════════════════════════════════════════════════════

# Seed layout for synthetic generation: one reference blot (IMG_2270.jpeg,
# 829×1096 px) with its band annotations.
#
# Each "bbox" is COCO-style [x, y, width, height] in pixels, origin at the
# top-left corner. "score" is not a standard COCO field. In this file it is
# used to pick the GLIGEN phrase, to colour boxes in the visualization, and
# as the starting value for score jitter. It is presumably a detector
# confidence — confirm with the data source.
# The standard COCO "area"/"iscrowd" fields are absent here;
# generate_synthetic_dataset adds them to the annotations it writes.
COCO_DATASET = {
    "info": {
        "description": "Western Blot Band Detection",
        "version": "1.0",
        "year": 2025,
    },
    "images": [
        {
            "id": 1,
            "file_name": "IMG_2270.jpeg",
            "width": 829,
            "height": 1096,
        }
    ],
    "annotations": [
        {"id": 1,  "image_id": 1, "category_id": 1, "bbox": [52, 118, 62, 28],  "score": 0.85},
        {"id": 2,  "image_id": 1, "category_id": 1, "bbox": [247, 295, 58, 30], "score": 0.52},
        {"id": 3,  "image_id": 1, "category_id": 1, "bbox": [463, 280, 72, 35], "score": 0.65},
        {"id": 4,  "image_id": 1, "category_id": 1, "bbox": [225, 410, 100, 50], "score": 0.80},
        {"id": 5,  "image_id": 1, "category_id": 1, "bbox": [355, 405, 95, 50],  "score": 0.82},
        {"id": 6,  "image_id": 1, "category_id": 1, "bbox": [478, 400, 110, 55], "score": 0.83},
        {"id": 7,  "image_id": 1, "category_id": 1, "bbox": [630, 410, 110, 45], "score": 0.91},  # NOTE(review): overlaps id 8 heavily (IoU ≈ 0.5) — possible duplicate detection; confirm
        {"id": 8,  "image_id": 1, "category_id": 1, "bbox": [638, 420, 80, 40],  "score": 0.73},
        {"id": 9,  "image_id": 1, "category_id": 1, "bbox": [365, 480, 85, 35],  "score": 0.76},
        {"id": 10, "image_id": 1, "category_id": 1, "bbox": [485, 490, 80, 40],  "score": 0.63},
        {"id": 11, "image_id": 1, "category_id": 1, "bbox": [635, 465, 85, 40],  "score": 0.87},
        {"id": 12, "image_id": 1, "category_id": 1, "bbox": [65, 560, 100, 50],  "score": 0.95},
        {"id": 13, "image_id": 1, "category_id": 1, "bbox": [220, 555, 110, 55], "score": 0.91},
        {"id": 14, "image_id": 1, "category_id": 1, "bbox": [355, 555, 100, 50], "score": 0.87},
        {"id": 15, "image_id": 1, "category_id": 1, "bbox": [478, 550, 120, 55], "score": 0.89},
        {"id": 16, "image_id": 1, "category_id": 1, "bbox": [630, 550, 110, 55], "score": 0.96},
        {"id": 17, "image_id": 1, "category_id": 1, "bbox": [230, 720, 75, 30],  "score": 0.94},
        {"id": 18, "image_id": 1, "category_id": 1, "bbox": [365, 720, 70, 30],  "score": 0.91},
        {"id": 19, "image_id": 1, "category_id": 1, "bbox": [490, 722, 75, 30],  "score": 0.84},
        {"id": 20, "image_id": 1, "category_id": 1, "bbox": [635, 725, 65, 28],  "score": 0.79},
    ],
    # Single class; category_id 1 is what every annotation above references.
    "categories": [
        {"id": 1, "name": "band", "supercategory": "protein"},
    ],
}


# ═══════════════════════════════════════════════════════════════════════════════
# 2. COCO → GLIGEN FORMAT CONVERSION
# ═══════════════════════════════════════════════════════════════════════════════

def coco_to_gligen(
    annotations: list[dict],
    image_width: int,
    image_height: int,
    phrase: "str | dict | None" = None,
    max_boxes: int = 30,
) -> tuple[list[str], list[list[float]]]:
    """
    Convert COCO annotations to GLIGEN's native input format.

    COCO bbox: [x, y, w, h] in pixels (top-left origin)
    GLIGEN box: [x1, y1, x2, y2] normalized to [0, 1]

    Parameters
    ----------
    annotations : list[dict]
        COCO-format annotations with a 'bbox' field. Optional 'score' and
        'category_id' fields are used to choose the phrase.
    image_width, image_height : int
        Image dimensions used for normalization.
    phrase : str, dict or None
        Grounding phrase per box.

        - None (default): derived from each annotation's 'score'
          (>= 0.8 "dark strong protein band", >= 0.6 "protein band",
          otherwise "faint light protein band"; a missing score counts
          as 1.0).
        - str: the same phrase for every box.
        - dict: maps category_id -> phrase. Annotations whose category is
          not in the dict fall back to the score-derived phrase.
    max_boxes : int
        Maximum number of boxes returned. Boxes that are degenerate after
        clamping are skipped and do NOT count toward this limit. A value
        <= 0 returns no boxes.

    Returns
    -------
    (phrases, boxes) : tuple[list[str], list[list[float]]]
        Parallel lists, ready to pass as ``gligen_phrases`` and
        ``gligen_boxes``.
    """
    phrases: list[str] = []
    boxes: list[list[float]] = []
    if max_boxes <= 0:
        return phrases, boxes

    for ann in annotations:
        x, y, w, h = ann["bbox"]

        # Normalize to [0, 1], then clamp so that boxes partially outside
        # the image are cropped to the image border.
        x1 = max(0.0, min(1.0, x / image_width))
        y1 = max(0.0, min(1.0, y / image_height))
        x2 = max(0.0, min(1.0, (x + w) / image_width))
        y2 = max(0.0, min(1.0, (y + h) / image_height))

        # Zero-area after clamping (or a non-positive w/h): nothing to ground.
        if x2 <= x1 or y2 <= y1:
            continue

        category = ann.get("category_id")
        if isinstance(phrase, dict) and category in phrase:
            p = phrase[category]
        elif isinstance(phrase, str):
            p = phrase
        else:
            # Vary the phrase with band intensity.
            score = ann.get("score", 1.0)
            if score >= 0.8:
                p = "dark strong protein band"
            elif score >= 0.6:
                p = "protein band"
            else:
                p = "faint light protein band"

        phrases.append(p)
        boxes.append([x1, y1, x2, y2])
        if len(boxes) >= max_boxes:
            break

    return phrases, boxes


def coco_to_yolo(
    annotations: list[dict],
    image_width: int,
    image_height: int,
) -> list[tuple[int, float, float, float, float]]:
    """
    Convert COCO annotations to YOLO rows (for reference / export).

    Each COCO [x, y, w, h] pixel box becomes a
    (class_id, center_x, center_y, width, height) tuple with every
    coordinate divided by the image size. class_id is the annotation's
    'category_id' exactly as stored (0 when absent); it is NOT remapped to
    YOLO's usual 0-based class indices.
    """
    def _as_row(ann: dict) -> tuple[int, float, float, float, float]:
        left, top, box_w, box_h = ann["bbox"]
        return (
            ann.get("category_id", 0),
            (left + box_w / 2) / image_width,
            (top + box_h / 2) / image_height,
            box_w / image_width,
            box_h / image_height,
        )

    return [_as_row(ann) for ann in annotations]


# ═══════════════════════════════════════════════════════════════════════════════
# 3. GLIGEN + IP-ADAPTER PIPELINE
# ═══════════════════════════════════════════════════════════════════════════════

def load_gligen_ip_adapter_pipeline(
    gligen_model_id: str = "masterful/gligen-1-4-generation-text-box",
    ip_adapter_model_id: str = "h94/IP-Adapter",
    ip_adapter_weight_name: str = "models/ip-adapter_sd15.bin",
    ip_adapter_subfolder: str = ".",
    device: str = "cuda",
) -> "StableDiffusionGLIGENPipeline":
    """
    Load the combined GLIGEN + IP-Adapter pipeline.

    - GLIGEN handles native bbox → spatial grounding
    - IP-Adapter handles reference image → style conditioning

    The pipeline is loaded once and reused across samples.

    Parameters
    ----------
    gligen_model_id : str
        Hub id or local path of the GLIGEN text-box checkpoint.
    ip_adapter_model_id : str
        Hub id or local path of the IP-Adapter repository.
    ip_adapter_weight_name, ip_adapter_subfolder : str
        Location of the IP-Adapter weights inside ``ip_adapter_model_id``.
        If the subfolder is "" or "." and the weight name contains a
        directory (e.g. "models/ip-adapter_sd15.bin"), that directory is
        moved into the subfolder. diffusers resolves the CLIP image encoder
        relative to ``subfolder``, so it is then found next to the weights.
    device : str
        Torch device. Any CUDA device ("cuda", "cuda:1", ...) loads fp16
        weights; other devices load fp32.

    Returns
    -------
    StableDiffusionGLIGENPipeline with IP-Adapter weights loaded, moved to
    ``device``.

    Raises
    ------
    NotImplementedError
        If the installed diffusers' GLIGEN pipeline has no
        ``load_ip_adapter`` method (i.e. no IP-Adapter support).
    """
    from diffusers import StableDiffusionGLIGENPipeline

    dtype = torch.float16 if str(device).startswith("cuda") else torch.float32

    pipe = StableDiffusionGLIGENPipeline.from_pretrained(
        gligen_model_id,
        torch_dtype=dtype,
        safety_checker=None,
        requires_safety_checker=False,
    )

    # Fail with an actionable message instead of a bare AttributeError.
    if not hasattr(pipe, "load_ip_adapter"):
        raise NotImplementedError(
            "StableDiffusionGLIGENPipeline in the installed diffusers version "
            "has no load_ip_adapter(); IP-Adapter cannot be combined with "
            "GLIGEN here. Use method='controlnet' instead."
        )

    # "models/ip-adapter_sd15.bin" with subfolder "." -> subfolder "models",
    # weight "ip-adapter_sd15.bin", so the image encoder resolves to
    # models/image_encoder rather than a non-existent root-level folder.
    if ip_adapter_subfolder in ("", ".") and "/" in ip_adapter_weight_name:
        ip_adapter_subfolder, ip_adapter_weight_name = ip_adapter_weight_name.rsplit("/", 1)

    # Load IP-Adapter on top of the GLIGEN pipeline
    pipe.load_ip_adapter(
        ip_adapter_model_id,
        subfolder=ip_adapter_subfolder,
        weight_name=ip_adapter_weight_name,
    )

    # NOTE: attention slicing / xformers are deliberately NOT enabled.
    # Both replace the UNet's attention processors and can discard the
    # IP-Adapter processors installed by load_ip_adapter, which silently
    # disables the reference-image conditioning. On PyTorch >= 2, diffusers
    # already uses scaled-dot-product attention by default.
    return pipe.to(device)


def generate_with_gligen_ip_adapter(
    pipe,
    annotations: list[dict],
    reference_image: Image.Image,
    image_width: int,
    image_height: int,
    prompt: str = (
        "a grayscale western blot gel electrophoresis image, dark protein "
        "bands on light background, scientific photograph, high contrast, sharp"
    ),
    negative_prompt: str = (
        "color, colorful, text, labels, bounding boxes, cartoon, drawing, "
        "blurry, low quality, watermark, oversaturated"
    ),
    ip_adapter_scale: float = 0.6,
    gligen_scheduled_sampling_beta: float = 0.3,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    seed: Optional[int] = None,
    device: str = "cuda",
) -> Image.Image:
    """
    Generate a single synthetic blot image.

    - Annotations (COCO bboxes) → tell GLIGEN where to place bands
    - Reference image → IP-Adapter injects its visual style

    Parameters
    ----------
    pipe : StableDiffusionGLIGENPipeline
        Pre-loaded pipeline with IP-Adapter.
    annotations : list[dict]
        COCO-format bounding boxes defining band positions.
    reference_image : PIL.Image
        Example blot image for style conditioning.
    image_width, image_height : int
        Target output dimensions. They are also the pixel space of the
        annotation bboxes.
    ip_adapter_scale : float
        How strongly the reference image influences the output (0–1).
        0.4–0.7 recommended: enough to learn "what a blot looks like"
        without copying the reference exactly.
    gligen_scheduled_sampling_beta : float
        GLIGEN grounding strength (0–1). Higher = bands more strictly
        placed at bbox positions. 0.3–0.6 recommended.
    seed : int or None
        Seed for reproducibility. None draws fresh noise from torch's global
        RNG, so repeated calls give different images.

    Returns
    -------
    PIL.Image – generated blot image of size (image_width, image_height).

    Raises
    ------
    ValueError
        If no annotation yields a non-degenerate box.
    """
    # Convert COCO → GLIGEN format
    phrases, boxes = coco_to_gligen(annotations, image_width, image_height)

    if not boxes:
        raise ValueError("No valid boxes after conversion")

    # Set IP-Adapter influence
    pipe.set_ip_adapter_scale(ip_adapter_scale)

    # Prepare reference image (IP-Adapter expects RGB).
    # NOTE(review): squashing to 256×256 distorts non-square references;
    # consider passing the RGB image unresized — confirm against the
    # pipeline's image processor.
    ref_rgb = reference_image.convert("RGB").resize((256, 256), Image.LANCZOS)

    # A freshly constructed torch.Generator starts from a fixed default seed.
    # Creating one unconditionally would make seed=None return the same image
    # on every call, so only build one when a seed is given; otherwise the
    # pipeline draws from the global RNG.
    generator = (
        torch.Generator(device=device).manual_seed(seed) if seed is not None else None
    )

    # Generate at 512×512. GLIGEN boxes are normalized, so stretching the
    # result to the target size keeps bands aligned with the annotations,
    # although the aspect ratio of the texture is distorted.
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        gligen_phrases=phrases,
        gligen_boxes=boxes,
        gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
        ip_adapter_image=ref_rgb,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        width=512,
        height=512,
    )
    synth = result.images[0]

    # Resize to target dimensions
    if synth.size != (image_width, image_height):
        synth = synth.resize((image_width, image_height), Image.LANCZOS)

    return synth


# ═══════════════════════════════════════════════════════════════════════════════
# 4. CONTROLNET + IP-ADAPTER FALLBACK
# ═══════════════════════════════════════════════════════════════════════════════

def annotations_to_scribble_map(
    annotations: list[dict],
    width: int,
    height: int,
    blur_radius: int = 2,
) -> Image.Image:
    """
    Rasterize annotation boxes into a grayscale ControlNet conditioning map.

    Every COCO [x, y, w, h] box is filled white (255) on a black (0)
    ``width``×``height`` "L"-mode canvas. When ``blur_radius`` is positive,
    the map is Gaussian-blurred to soften the band edges.
    """
    canvas = Image.new("L", (width, height), 0)
    pen = ImageDraw.Draw(canvas)
    for left, top, box_w, box_h in (ann["bbox"] for ann in annotations):
        pen.rectangle([left, top, left + box_w, top + box_h], fill=255)

    if blur_radius <= 0:
        return canvas
    return canvas.filter(ImageFilter.GaussianBlur(radius=blur_radius))


def load_controlnet_ip_adapter_pipeline(
    sd_model_id: str = "runwayml/stable-diffusion-v1-5",
    controlnet_model_id: str = "lllyasviel/control_v11p_sd15_scribble",
    ip_adapter_model_id: str = "h94/IP-Adapter",
    ip_adapter_weight_name: str = "models/ip-adapter_sd15.bin",
    ip_adapter_subfolder: str = ".",
    device: str = "cuda",
):
    """
    Load a ControlNet + IP-Adapter pipeline (fallback for GLIGEN).

    Parameters
    ----------
    sd_model_id : str
        Base Stable Diffusion 1.x checkpoint.
        NOTE(review): the "runwayml/stable-diffusion-v1-5" Hub repo has
        reportedly been removed; "stable-diffusion-v1-5/stable-diffusion-v1-5"
        may be needed instead — confirm.
    controlnet_model_id : str
        ControlNet checkpoint (scribble by default), fed the rendered boxes.
    ip_adapter_model_id : str
        Hub id or local path of the IP-Adapter repository.
    ip_adapter_weight_name, ip_adapter_subfolder : str
        Location of the IP-Adapter weights. If the subfolder is "" or "." and
        the weight name contains a directory (e.g.
        "models/ip-adapter_sd15.bin"), that directory is moved into the
        subfolder. diffusers resolves the CLIP image encoder relative to
        ``subfolder``, so it is then found next to the weights.
    device : str
        Torch device. Any CUDA device ("cuda", "cuda:1", ...) loads fp16
        weights; other devices load fp32.

    Returns
    -------
    StableDiffusionControlNetPipeline with a UniPC scheduler and IP-Adapter
    loaded, moved to ``device``.
    """
    from diffusers import (
        StableDiffusionControlNetPipeline,
        ControlNetModel,
        UniPCMultistepScheduler,
    )

    dtype = torch.float16 if str(device).startswith("cuda") else torch.float32

    controlnet = ControlNetModel.from_pretrained(
        controlnet_model_id, torch_dtype=dtype,
    )
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        sd_model_id,
        controlnet=controlnet,
        torch_dtype=dtype,
        safety_checker=None,
        requires_safety_checker=False,
    )
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

    # "models/ip-adapter_sd15.bin" with subfolder "." -> subfolder "models",
    # weight "ip-adapter_sd15.bin", so the image encoder resolves to
    # models/image_encoder rather than a non-existent root-level folder.
    if ip_adapter_subfolder in ("", ".") and "/" in ip_adapter_weight_name:
        ip_adapter_subfolder, ip_adapter_weight_name = ip_adapter_weight_name.rsplit("/", 1)

    pipe.load_ip_adapter(
        ip_adapter_model_id,
        subfolder=ip_adapter_subfolder,
        weight_name=ip_adapter_weight_name,
    )

    # NOTE: attention slicing / xformers are deliberately NOT enabled.
    # Both replace the UNet's attention processors and can discard the
    # IP-Adapter processors installed by load_ip_adapter, which silently
    # disables the reference-image conditioning. On PyTorch >= 2, diffusers
    # already uses scaled-dot-product attention by default.
    return pipe.to(device)


def generate_with_controlnet_ip_adapter(
    pipe,
    annotations: list[dict],
    reference_image: Image.Image,
    image_width: int,
    image_height: int,
    prompt: str = (
        "a grayscale western blot gel electrophoresis image, dark protein "
        "bands on light background, scientific photograph, high contrast, sharp"
    ),
    negative_prompt: str = (
        "color, colorful, text, labels, bounding boxes, cartoon, drawing, "
        "blurry, low quality, watermark, oversaturated"
    ),
    ip_adapter_scale: float = 0.6,
    controlnet_conditioning_scale: float = 1.0,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    seed: Optional[int] = None,
    device: str = "cuda",
) -> Image.Image:
    """
    Generate one image with ControlNet (box scribble map) + IP-Adapter (reference).

    The annotation boxes are rendered into a scribble map at
    ``image_width``×``image_height``. The map is resized to the nearest
    multiple-of-8 size for Stable Diffusion, and the result is resized back
    to ``(image_width, image_height)``, so the boxes stay aligned with the
    output. ``seed=None`` draws fresh noise from torch's global RNG; an int
    seed makes the output reproducible.
    """
    pipe.set_ip_adapter_scale(ip_adapter_scale)

    control_image = annotations_to_scribble_map(
        annotations, image_width, image_height,
    )
    # Stable Diffusion needs sides divisible by 8 (at least 8). No
    # width/height is passed to the pipeline, so it takes its size from the
    # control image.
    gen_w = max(8, (image_width // 8) * 8)
    gen_h = max(8, (image_height // 8) * 8)
    control_rgb = control_image.resize((gen_w, gen_h), Image.LANCZOS).convert("RGB")
    # NOTE(review): squashing to 256×256 distorts non-square references.
    ref_rgb = reference_image.convert("RGB").resize((256, 256), Image.LANCZOS)

    # A fresh torch.Generator has a fixed default seed, so creating one
    # unconditionally would make seed=None deterministic. Only build one when
    # a seed is requested.
    generator = (
        torch.Generator(device=device).manual_seed(seed) if seed is not None else None
    )

    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=control_rgb,
        ip_adapter_image=ref_rgb,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )
    synth = result.images[0]

    if synth.size != (image_width, image_height):
        synth = synth.resize((image_width, image_height), Image.LANCZOS)
    return synth


# ═══════════════════════════════════════════════════════════════════════════════
# 5. ANNOTATION PERTURBATION (for layout variation)
# ═══════════════════════════════════════════════════════════════════════════════

def perturb_annotations(
    annotations: list[dict],
    image_width: int,
    image_height: int,
    position_jitter: float = 0.1,
    size_jitter: float = 0.15,
    drop_prob: float = 0.1,
    rng: Optional[random.Random] = None,
) -> list[dict]:
    """
    Create a varied layout from seed annotations by randomly shifting,
    scaling, and dropping bands. Useful for generating diverse training data.

    Parameters
    ----------
    annotations : list[dict]
        Seed COCO annotations ([x, y, w, h] pixel bboxes). Not modified.
    image_width, image_height : int
        Image bounds. Jittered boxes are kept inside them, and a box scaled
        larger than the image is shrunk to fit.
    position_jitter : float
        Std-dev of the x/y shift, as a fraction of the box's width/height.
    size_jitter : float
        Width/height are scaled by a factor drawn uniformly from
        [1 - size_jitter, 1 + size_jitter] (anchored at the top-left corner).
    drop_prob : float
        Probability of dropping each band entirely.
    rng : random.Random or None
        Random source for reproducible layouts. None uses the module-level
        ``random`` functions (shared global state), as before.

    Returns
    -------
    list[dict]
        Deep copies of the kept annotations, with jittered 'bbox' values
        (rounded to 0.1 px) and 'score' = original score (default 1.0) plus
        N(0, 0.1) noise, clamped to [0.3, 1.0] and rounded to 2 decimals.
    """
    # The module and Random instances expose the same methods; the call order
    # below is unchanged, so rng=None reproduces the original global-RNG draws.
    rand = random if rng is None else rng

    new_anns = []
    for ann in annotations:
        if rand.random() < drop_prob:
            continue

        x, y, w, h = ann["bbox"]
        dx = rand.gauss(0, position_jitter * w)
        dy = rand.gauss(0, position_jitter * h)
        sw = rand.uniform(1 - size_jitter, 1 + size_jitter)
        sh = rand.uniform(1 - size_jitter, 1 + size_jitter)
        new_score = min(1.0, max(0.3, ann.get("score", 1.0) + rand.gauss(0, 0.1)))

        # Never let a scaled box exceed the image, otherwise it would spill
        # past the right/bottom border even after position clamping.
        new_w = min(w * sw, image_width)
        new_h = min(h * sh, image_height)
        nx = max(0, min(image_width - new_w, x + dx))
        ny = max(0, min(image_height - new_h, y + dy))

        new_ann = copy.deepcopy(ann)
        new_ann["bbox"] = [round(nx, 1), round(ny, 1),
                           round(new_w, 1), round(new_h, 1)]
        new_ann["score"] = round(new_score, 2)
        new_anns.append(new_ann)

    return new_anns


# ═══════════════════════════════════════════════════════════════════════════════
# 6. MAIN DATASET GENERATION PIPELINE
# ═══════════════════════════════════════════════════════════════════════════════

def generate_synthetic_dataset(
    coco_dict: dict,
    reference_image_path: str,
    output_dir: str = "synthetic_dataset",
    n_samples: int = 10,
    method: str = "gligen",
    perturb_layout: bool = False,
    # ── GLIGEN params ──
    gligen_model_id: str = "masterful/gligen-1-4-generation-text-box",
    gligen_scheduled_sampling_beta: float = 0.3,
    # ── ControlNet params (fallback) ──
    sd_model_id: str = "runwayml/stable-diffusion-v1-5",
    controlnet_model_id: str = "lllyasviel/control_v11p_sd15_scribble",
    controlnet_conditioning_scale: float = 1.0,
    # ── IP-Adapter params ──
    ip_adapter_model_id: str = "h94/IP-Adapter",
    ip_adapter_weight_name: str = "models/ip-adapter_sd15.bin",
    ip_adapter_subfolder: str = ".",
    ip_adapter_scale: float = 0.6,
    # ── Shared diffusion params ──
    prompt: Optional[str] = None,
    negative_prompt: Optional[str] = None,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    device: str = "cuda",
) -> dict:
    """
    Generate a synthetic dataset of blot images.

    COCO annotations define WHERE bands appear (fed to GLIGEN as native
    bboxes, or rendered as scribble map for ControlNet).

    Reference image defines WHAT the output looks like (fed to IP-Adapter
    so the model learns western blot appearance from your example).

    Writes ``<output_dir>/images/synth_XXXX.png``,
    ``<output_dir>/annotations.json`` and ``<output_dir>/reference.png``.

    Parameters
    ----------
    coco_dict : dict
        COCO annotation dict. Only its FIRST image and that image's
        annotations are used as the layout seed.
    reference_image_path : str
        Path to a real blot image. IP-Adapter uses this to learn the
        visual style (gel texture, band appearance, contrast, etc).
    method : str
        'gligen' — native bbox input (preferred)
        'controlnet' — scribble map fallback
    perturb_layout : bool
        If True, randomly jitter the annotation layout per sample for
        more diverse training data.
    ip_adapter_scale : float
        Reference image influence (0–1). Higher = output looks more
        like the reference. 0.4–0.7 recommended.
    gligen_scheduled_sampling_beta : float
        GLIGEN grounding strength (0–1). Higher = stricter bbox adherence.
        0.3–0.6 recommended.

    Returns
    -------
    dict – COCO annotation dict for the synthetic dataset.

    Raises
    ------
    ValueError
        If ``method`` is not 'gligen' or 'controlnet', or ``coco_dict``
        has no images.
    """
    # Validate before touching the filesystem or loading models, so a typo
    # such as method="GLIGEN" fails fast instead of silently running ControlNet.
    if method not in ("gligen", "controlnet"):
        raise ValueError(
            f"Unknown method {method!r}; expected 'gligen' or 'controlnet'"
        )
    if not coco_dict.get("images"):
        raise ValueError("coco_dict contains no images")

    out = Path(output_dir)
    images_dir = out / "images"
    images_dir.mkdir(parents=True, exist_ok=True)

    # ── Load reference image (fully read; file handle released) ──
    with Image.open(reference_image_path) as ref_file:
        reference_image = ref_file.copy()

    # ── Extract source annotations (first image only) ──
    img_info = coco_dict["images"][0]
    W, H = img_info["width"], img_info["height"]
    src_annotations = [a for a in coco_dict.get("annotations", [])
                       if a["image_id"] == img_info["id"]]

    default_prompt = (
        "a grayscale western blot gel electrophoresis image, dark protein "
        "bands on light background, scientific photograph, high contrast, sharp"
    )
    default_neg = (
        "color, colorful, text, labels, bounding boxes, cartoon, drawing, "
        "blurry, low quality, watermark, oversaturated"
    )
    prompt = prompt or default_prompt
    negative_prompt = negative_prompt or default_neg

    # "models/ip-adapter_sd15.bin" with subfolder "." -> subfolder "models":
    # diffusers resolves the IP-Adapter image encoder relative to the
    # subfolder, i.e. models/image_encoder.
    if ip_adapter_subfolder in ("", ".") and "/" in ip_adapter_weight_name:
        ip_adapter_subfolder, ip_adapter_weight_name = ip_adapter_weight_name.rsplit("/", 1)

    # ── Load pipeline once (reused across all samples) ──
    print(f"Loading {method} + IP-Adapter pipeline...")
    if method == "gligen":
        pipe = load_gligen_ip_adapter_pipeline(
            gligen_model_id=gligen_model_id,
            ip_adapter_model_id=ip_adapter_model_id,
            ip_adapter_weight_name=ip_adapter_weight_name,
            ip_adapter_subfolder=ip_adapter_subfolder,
            device=device,
        )
    else:
        pipe = load_controlnet_ip_adapter_pipeline(
            sd_model_id=sd_model_id,
            controlnet_model_id=controlnet_model_id,
            ip_adapter_model_id=ip_adapter_model_id,
            ip_adapter_weight_name=ip_adapter_weight_name,
            ip_adapter_subfolder=ip_adapter_subfolder,
            device=device,
        )
    print("Pipeline loaded ✓")

    # ── Output COCO dict ──
    src_info = coco_dict.get("info", {})
    description = f"Synthetic ({method}+ip-adapter)"
    if src_info.get("description") is not None:
        description += " " + src_info["description"]
    synth_coco = {
        "info": {**src_info, "description": description},
        "images": [],
        "annotations": [],
        "categories": copy.deepcopy(coco_dict.get("categories", [])),
    }
    ann_id_counter = 1

    # ── Generate samples ──
    for i in range(n_samples):
        seed = random.randint(0, 2**32 - 1)

        # Optionally perturb layout
        if perturb_layout:
            sample_anns = perturb_annotations(src_annotations, W, H)
        else:
            sample_anns = copy.deepcopy(src_annotations)

        print(f"\n[{i+1}/{n_samples}] seed={seed}  boxes={len(sample_anns)}")

        if method == "gligen":
            synth_img = generate_with_gligen_ip_adapter(
                pipe=pipe,
                annotations=sample_anns,
                reference_image=reference_image,
                image_width=W,
                image_height=H,
                prompt=prompt,
                negative_prompt=negative_prompt,
                ip_adapter_scale=ip_adapter_scale,
                gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                seed=seed,
                device=device,
            )
        else:
            synth_img = generate_with_controlnet_ip_adapter(
                pipe=pipe,
                annotations=sample_anns,
                reference_image=reference_image,
                image_width=W,
                image_height=H,
                prompt=prompt,
                negative_prompt=negative_prompt,
                ip_adapter_scale=ip_adapter_scale,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                seed=seed,
                device=device,
            )

        # ── Save ──
        fname = f"synth_{i:04d}.png"
        synth_img.save(images_dir / fname)

        # Offset ids so synthetic image ids are distinct from small source ids.
        img_id = i + 1000
        synth_coco["images"].append({
            "id": img_id,
            "file_name": fname,
            "width": W,
            "height": H,
        })

        for ann in sample_anns:
            new_ann = copy.deepcopy(ann)
            new_ann["id"] = ann_id_counter
            new_ann["image_id"] = img_id
            bw, bh = new_ann["bbox"][2], new_ann["bbox"][3]
            new_ann["area"] = round(bw * bh, 2)
            new_ann["iscrowd"] = 0
            synth_coco["annotations"].append(new_ann)
            ann_id_counter += 1

        print(f"  ✓ {fname}  ({len(sample_anns)} bands)")

    # ── Write annotation file ──
    ann_path = out / "annotations.json"
    with open(ann_path, "w", encoding="utf-8") as f:
        json.dump(synth_coco, f, indent=2)

    # ── Save reference for reproducibility ──
    reference_image.save(out / "reference.png")

    print(f"\n✅ Done: {n_samples} images → {out}")
    print(f"   Annotations: {ann_path}")
    return synth_coco


# ═══════════════════════════════════════════════════════════════════════════════
# 7. VISUALIZATION
# ═══════════════════════════════════════════════════════════════════════════════

def visualize_annotations(image: Image.Image, annotations: list[dict]) -> Image.Image:
    """
    Return an RGB copy of *image* with every annotation box outlined.

    Each box gets a 2-px outline, and its score (missing score counts as
    1.0), formatted to two decimals, is written 12 px above the top-left
    corner. Scores >= 0.7 are drawn in lime, all others in yellow. The input
    image is left untouched.
    """
    canvas = image.copy().convert("RGB")
    pen = ImageDraw.Draw(canvas)
    for ann in annotations:
        left, top, box_w, box_h = ann["bbox"]
        conf = ann.get("score", 1.0)
        colour = "lime" if conf >= 0.7 else "yellow"
        corners = [left, top, left + box_w, top + box_h]
        pen.rectangle(corners, outline=colour, width=2)
        pen.text((left, top - 12), f"{conf:.2f}", fill=colour)
    return canvas


def visualize_gligen_inputs(
    annotations: list[dict],
    image_width: int,
    image_height: int,
):
    """
    Show what GLIGEN receives: normalized boxes overlaid on a blank canvas.

    Boxes and phrases come from ``coco_to_gligen``, so degenerate boxes are
    dropped exactly as in generation. They are drawn in normalized [0, 1]
    coordinates with the origin at the top-left, on a canvas with the same
    aspect ratio as the source image. Requires matplotlib; displays the
    figure with ``plt.show()``.
    """
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches

    phrases, boxes = coco_to_gligen(annotations, image_width, image_height)

    _, ax = plt.subplots(1, 1, figsize=(8, 10))
    ax.set_xlim(0, 1)
    ax.set_ylim(1, 0)  # flip y so origin is top-left
    # set_aspect is the on-screen length of one y-unit relative to one
    # x-unit. Both axes span [0, 1], so H/W gives the canvas the source
    # image's proportions (W/H would draw a portrait blot as landscape).
    ax.set_aspect(image_height / image_width)
    ax.set_facecolor("#f0f0f0")
    ax.set_title(f"GLIGEN input: {len(boxes)} boxes (normalized coords)")

    for phrase, (x1, y1, x2, y2) in zip(phrases, boxes):
        rect = patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            linewidth=2, edgecolor="lime", facecolor="black", alpha=0.6,
        )
        ax.add_patch(rect)
        # With the inverted y-axis, y1 - 0.005 is just above the box.
        ax.text(x1, y1 - 0.005, phrase, fontsize=6, color="lime",
                verticalalignment="bottom")

    plt.tight_layout()
    plt.show()


def show_comparison(
    synth_dir: str,
    reference_path: str,
    n: int = 4,
    reference_annotations: Optional[list[dict]] = None,
):
    """
    Display the reference image and up to ``min(n, 4)`` generated samples side by side.

    Parameters
    ----------
    synth_dir : str
        Output directory of ``generate_synthetic_dataset`` (must contain
        ``annotations.json`` and ``images/``).
    reference_path : str
        Path to the reference image shown in the first panel.
    n : int
        Number of generated samples to show (capped at 4 and at the number
        available; values <= 0 show only the reference).
    reference_annotations : list[dict] or None
        Boxes drawn on the reference. Defaults to the module-level
        ``COCO_DATASET`` annotations, as before.
    """
    import matplotlib.pyplot as plt

    synth_path = Path(synth_dir)
    with open(synth_path / "annotations.json", encoding="utf-8") as f:
        synth_coco = json.load(f)

    if reference_annotations is None:
        reference_annotations = COCO_DATASET["annotations"]

    # Only create panels for samples that actually exist, so there are no
    # blank axes when fewer than n images were generated.
    shown = synth_coco["images"][: max(0, min(n, 4))]
    cols = len(shown) + 1
    # squeeze=False always returns a 2-D array, even for a single panel.
    _, axes = plt.subplots(1, cols, figsize=(5 * cols, 7), squeeze=False)
    axes = axes[0]

    with Image.open(reference_path) as ref:
        axes[0].imshow(visualize_annotations(ref, reference_annotations))
    axes[0].set_title("Reference\n(IP-Adapter input)", fontweight="bold")
    axes[0].axis("off")

    for idx, (ax, img_info) in enumerate(zip(axes[1:], shown), start=1):
        anns = [a for a in synth_coco["annotations"]
                if a["image_id"] == img_info["id"]]
        with Image.open(synth_path / "images" / img_info["file_name"]) as img:
            ax.imshow(visualize_annotations(img, anns))
        ax.set_title(f"Generated #{idx}\n({len(anns)} bands)")
        ax.axis("off")

    plt.suptitle("Reference image → style  |  COCO annotations → band positions",
                 fontsize=11, y=1.02)
    plt.tight_layout()
    plt.show()
