"""
Synthetic Blot Generator — Layout-Driven Generation
=====================================================
Given COCO bounding-box annotations, generates NEW blot images with bands
appearing at the annotated locations. The annotations define the spatial
layout; the diffusion model fills in realistic appearance.

Two generation strategies:
  1. ControlNet  (preferred) — layout mask as spatial control signal
  2. Template img2img (fallback) — render a synthetic template, then
     img2img it into a realistic blot

Usage from Jupyter:
    from synthetic_blot_generator import *
    synth_coco = generate_synthetic_dataset(...)

Requirements:
    pip install diffusers transformers accelerate torch pillow numpy safetensors controlnet-aux
"""

# %% ── Imports ──────────────────────────────────────────────────────────────────
import json
import copy
import random
from pathlib import Path
from typing import Optional

import numpy as np
from PIL import Image, ImageDraw, ImageFilter
import torch


# %% ── 1. COCO annotations ─────────────────────────────────────────────────────

# Seed layout for generation: one annotated western-blot image.
# Each annotation bbox is COCO-style [x, y, width, height] in pixels and marks
# where a band must appear; 'score' (0-1) doubles as a band-strength proxy
# (higher score → darker band in the rendered conditioning images).
COCO_DATASET = {
    "info": {
        "description": "Western Blot Band Detection",
        "version": "1.0",
        "year": 2025,
    },
    "images": [
        {
            "id": 1,
            "file_name": "IMG_2270.jpeg",
            "width": 829,
            "height": 1096,
        }
    ],
    "annotations": [
        # bbox = [x, y, w, h]; category_id 1 == "band" (see "categories" below)
        {"id": 1,  "image_id": 1, "category_id": 1, "bbox": [52, 118, 62, 28],  "score": 0.85},
        {"id": 2,  "image_id": 1, "category_id": 1, "bbox": [247, 295, 58, 30], "score": 0.52},
        {"id": 3,  "image_id": 1, "category_id": 1, "bbox": [463, 280, 72, 35], "score": 0.65},
        {"id": 4,  "image_id": 1, "category_id": 1, "bbox": [225, 410, 100, 50], "score": 0.80},
        {"id": 5,  "image_id": 1, "category_id": 1, "bbox": [355, 405, 95, 50],  "score": 0.82},
        {"id": 6,  "image_id": 1, "category_id": 1, "bbox": [478, 400, 110, 55], "score": 0.83},
        {"id": 7,  "image_id": 1, "category_id": 1, "bbox": [630, 410, 110, 45], "score": 0.91},
        {"id": 8,  "image_id": 1, "category_id": 1, "bbox": [638, 420, 80, 40],  "score": 0.73},
        {"id": 9,  "image_id": 1, "category_id": 1, "bbox": [365, 480, 85, 35],  "score": 0.76},
        {"id": 10, "image_id": 1, "category_id": 1, "bbox": [485, 490, 80, 40],  "score": 0.63},
        {"id": 11, "image_id": 1, "category_id": 1, "bbox": [635, 465, 85, 40],  "score": 0.87},
        {"id": 12, "image_id": 1, "category_id": 1, "bbox": [65, 560, 100, 50],  "score": 0.95},
        {"id": 13, "image_id": 1, "category_id": 1, "bbox": [220, 555, 110, 55], "score": 0.91},
        {"id": 14, "image_id": 1, "category_id": 1, "bbox": [355, 555, 100, 50], "score": 0.87},
        {"id": 15, "image_id": 1, "category_id": 1, "bbox": [478, 550, 120, 55], "score": 0.89},
        {"id": 16, "image_id": 1, "category_id": 1, "bbox": [630, 550, 110, 55], "score": 0.96},
        {"id": 17, "image_id": 1, "category_id": 1, "bbox": [230, 720, 75, 30],  "score": 0.94},
        {"id": 18, "image_id": 1, "category_id": 1, "bbox": [365, 720, 70, 30],  "score": 0.91},
        {"id": 19, "image_id": 1, "category_id": 1, "bbox": [490, 722, 75, 30],  "score": 0.84},
        {"id": 20, "image_id": 1, "category_id": 1, "bbox": [635, 725, 65, 28],  "score": 0.79},
    ],
    "categories": [
        {"id": 1, "name": "band", "supercategory": "protein"},
    ],
}


# %% ── 2. Render annotations into conditioning images ──────────────────────────

def annotations_to_control_image(
    annotations: list[dict],
    width: int,
    height: int,
    blur_radius: int = 3,
) -> Image.Image:
    """
    Render COCO bounding boxes into a ControlNet-style control image.

    The canvas is white (light gel background); each bbox is filled with a
    grey level proportional to its annotation score — score 1.0 maps to 0
    (black, strongest band), score 0.5 to roughly mid-grey. The result is
    fed directly to ControlNet as the spatial control signal.
    """
    canvas = Image.new("L", (width, height), 255)
    pen = ImageDraw.Draw(canvas)
    for ann in annotations:
        bx, by, bw, bh = ann["bbox"]
        # Darker fill for stronger bands; missing score defaults to 1.0.
        fill = int((1.0 - ann.get("score", 1.0)) * 255)
        pen.rectangle([bx, by, bx + bw, by + bh], fill=fill)
    if blur_radius > 0:
        canvas = canvas.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    return canvas


def annotations_to_scribble_map(
    annotations: list[dict],
    width: int,
    height: int,
) -> Image.Image:
    """
    Render annotations as a scribble/edge map for ControlNet-scribble.

    Black canvas with white filled rectangles at each bbox (the scribble
    model expects white = structure), lightly blurred to soften edges.
    """
    canvas = Image.new("L", (width, height), 0)
    pen = ImageDraw.Draw(canvas)
    for ann in annotations:
        bx, by, bw, bh = ann["bbox"]
        pen.rectangle([bx, by, bx + bw, by + bh], fill=255)
    # Fixed 2px blur softens the hard rectangle edges for the model.
    return canvas.filter(ImageFilter.GaussianBlur(radius=2))


def annotations_to_template(
    annotations: list[dict],
    width: int,
    height: int,
    reference_image: Optional[Image.Image] = None,
) -> Image.Image:
    """
    Render a synthetic blot template from annotations.
    Used as the init image for img2img (fallback when ControlNet unavailable).

    If a reference_image is provided, band intensities are sampled from
    the reference at each bbox location for more realistic templates.
    """
    # Light grey background with slight vertical gradient (like a real gel)
    bg_top = random.randint(180, 210)
    bg_bot = random.randint(190, 220)
    arr = np.zeros((height, width), dtype=np.float32)
    for row in range(height):
        arr[row, :] = bg_top + (bg_bot - bg_top) * (row / height)

    # Add subtle vertical lane stripes
    n_lanes = 6
    lane_width = width // (n_lanes + 1)
    for lane in range(n_lanes):
        cx = (lane + 1) * lane_width
        for col in range(width):
            dist = abs(col - cx) / (lane_width * 0.4)
            if dist < 1.0:
                arr[:, col] -= 8 * (1.0 - dist)  # subtle darkening

    # Draw bands
    for ann in annotations:
        x, y, w, h = ann["bbox"]
        score = ann.get("score", 1.0)

        if reference_image is not None:
            # Sample mean intensity from reference at this bbox
            ref_arr = np.array(reference_image.convert("L"), dtype=np.float32)
            x1c = max(0, int(x))
            y1c = max(0, int(y))
            x2c = min(width, int(x + w))
            y2c = min(height, int(y + h))
            if x2c > x1c and y2c > y1c:
                band_intensity = ref_arr[y1c:y2c, x1c:x2c].mean()
            else:
                band_intensity = 40
        else:
            # Synthetic intensity: strong bands are very dark
            band_intensity = 20 + (1.0 - score) * 100

        # Create a soft elliptical band shape
        yy, xx = np.ogrid[0:int(h), 0:int(w)]
        cy_b, cx_b = h / 2, w / 2
        # Elliptical falloff
        dist = ((xx - cx_b) / (cx_b + 1e-6)) ** 2 + ((yy - cy_b) / (cy_b + 1e-6)) ** 2
        mask = np.clip(1.0 - dist, 0, 1)
        # Soften
        mask = mask ** 0.6

        x1, y1 = int(x), int(y)
        x2, y2 = x1 + int(w), y1 + int(h)
        x1c, y1c = max(0, x1), max(0, y1)
        x2c, y2c = min(width, x2), min(height, y2)
        mw, mh = x2c - x1c, y2c - y1c
        if mw > 0 and mh > 0:
            m = mask[:mh, :mw]
            region = arr[y1c:y2c, x1c:x2c]
            arr[y1c:y2c, x1c:x2c] = region * (1 - m) + band_intensity * m

    # Add noise
    noise = np.random.normal(0, 3, arr.shape)
    arr = np.clip(arr + noise, 0, 255).astype(np.uint8)
    return Image.fromarray(arr, mode="L")


def visualize_annotations(image: Image.Image, annotations: list[dict]) -> Image.Image:
    """Return an RGB copy of *image* with annotation boxes drawn for visual QA.

    Boxes with score >= 0.7 are drawn in lime, weaker ones in yellow; the
    score is printed just above each box. The input image is not modified.
    """
    canvas = image.copy().convert("RGB")
    pen = ImageDraw.Draw(canvas)
    for ann in annotations:
        bx, by, bw, bh = ann["bbox"]
        conf = ann.get("score", 1.0)
        box_color = "lime" if conf >= 0.7 else "yellow"
        pen.rectangle([bx, by, bx + bw, by + bh], outline=box_color, width=2)
        pen.text((bx, by - 12), f"{conf:.2f}", fill=box_color)
    return canvas


# %% ── 3. ControlNet generation (preferred) ──────────────────────────────────

def generate_with_controlnet(
    annotations: list[dict],
    width: int,
    height: int,
    prompt: str = (
        "a high quality grayscale western blot gel electrophoresis image, "
        "dark protein bands on light background, sharp bands, "
        "scientific laboratory photograph, high contrast"
    ),
    negative_prompt: str = (
        "color, text, labels, annotations, bounding boxes, cartoon, "
        "drawing, blurry, low quality, watermark"
    ),
    controlnet_conditioning_scale: float = 1.0,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    seed: Optional[int] = None,
    device: str = "cuda",
    sd_model_id: str = "runwayml/stable-diffusion-v1-5",
    controlnet_model_id: str = "lllyasviel/control_v11p_sd15_scribble",
) -> Image.Image:
    """
    Generate a new blot image with bands at the annotated locations,
    using ControlNet for spatial conditioning.

    The annotations are rasterized into a scribble map (white filled
    rectangles on black = "draw bands here"); ControlNet steers the
    diffusion so bands land at exactly those positions.

    Parameters
    ----------
    annotations : list[dict]
        COCO-format annotations. Only 'bbox' and 'score' are used.
    width, height : int
        Output image dimensions.
    controlnet_conditioning_scale : float
        How strictly ControlNet follows the layout (0–2). Higher = stricter.
    guidance_scale : float
        Classifier-free guidance. Higher = more prompt-adherent.
    seed : int or None
        For reproducibility.
    device : str
        'cuda' (fp16) or 'cpu' (fp32).
    sd_model_id, controlnet_model_id : str
        Hugging Face model ids for the base SD model and the ControlNet.

    Returns
    -------
    PIL.Image – generated blot image at (width, height).
    """
    from diffusers import (
        StableDiffusionControlNetPipeline,
        ControlNetModel,
        UniPCMultistepScheduler,
    )

    # fp16 on GPU, fp32 on CPU — computed once for both model loads
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Load ControlNet weights, then build the SD pipeline around them
    controlnet = ControlNetModel.from_pretrained(
        controlnet_model_id,
        torch_dtype=dtype,
    )
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        sd_model_id,
        controlnet=controlnet,
        torch_dtype=dtype,
        safety_checker=None,
        requires_safety_checker=False,
    )
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)

    if device == "cuda":
        pipe.enable_attention_slicing()
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass  # xformers is an optional speedup; fall back silently

    # Layout → control signal. SD latents need dims divisible by 8, and
    # ControlNet expects a 3-channel image.
    layout = annotations_to_scribble_map(annotations, width, height)
    gen_w = (width // 8) * 8
    gen_h = (height // 8) * 8
    control_rgb = layout.resize((gen_w, gen_h), Image.LANCZOS).convert("RGB")

    rng = torch.Generator(device=device)
    if seed is not None:
        rng.manual_seed(seed)

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=control_rgb,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        generator=rng,
    )
    synth = output.images[0]

    # Snap back to the exact requested dimensions
    if synth.size != (width, height):
        synth = synth.resize((width, height), Image.LANCZOS)
    return synth


# %% ── 4. Template img2img generation (fallback) ─────────────────────────────

def generate_with_img2img(
    annotations: list[dict],
    width: int,
    height: int,
    reference_image: Optional[Image.Image] = None,
    prompt: str = (
        "a high quality grayscale western blot gel electrophoresis image, "
        "dark protein bands on light background, sharp bands, "
        "scientific laboratory photograph, high contrast"
    ),
    negative_prompt: str = (
        "color, text, labels, annotations, bounding boxes, cartoon, "
        "drawing, blurry, low quality, watermark"
    ),
    strength: float = 0.55,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    seed: Optional[int] = None,
    device: str = "cuda",
    model_id: str = "stabilityai/stable-diffusion-2-1",
) -> Image.Image:
    """
    Fallback generation without ControlNet.

    Renders the annotations into a synthetic blot template (dark bands on
    light background at the bbox positions), then runs img2img to turn it
    into a realistic-looking blot. Band positions survive because the
    denoising strength is kept moderate (0.4–0.6).

    Parameters
    ----------
    annotations : list[dict]
        COCO-format annotations.
    width, height : int
        Output image dimensions.
    reference_image : PIL.Image or None
        If provided, band intensities in the template are sampled from
        the reference for more realistic seeding.
    strength : float
        img2img denoising strength. 0.4–0.6 keeps band positions intact
        while making the image look realistic.
    seed : int or None
        For reproducibility.
    device : str
        'cuda' (fp16) or 'cpu' (fp32).
    model_id : str
        Hugging Face model id of the base SD model.

    Returns
    -------
    PIL.Image – generated blot image.
    """
    from diffusers import StableDiffusionImg2ImgPipeline, DDIMScheduler

    # Rasterize the layout into the img2img init image
    template = annotations_to_template(
        annotations, width, height, reference_image=reference_image,
    )

    # fp16 on GPU, fp32 on CPU
    dtype = torch.float16 if device == "cuda" else torch.float32

    scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_id,
        scheduler=scheduler,
        torch_dtype=dtype,
        safety_checker=None,
        requires_safety_checker=False,
    ).to(device)

    if device == "cuda":
        pipe.enable_attention_slicing()
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass  # xformers is an optional speedup; fall back silently

    # SD latents need dims divisible by 8; the pipeline wants RGB input
    gen_w = (width // 8) * 8
    gen_h = (height // 8) * 8
    init_image = template.convert("RGB").resize((gen_w, gen_h), Image.LANCZOS)

    rng = torch.Generator(device=device)
    if seed is not None:
        rng.manual_seed(seed)

    synth = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=init_image,
        strength=strength,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=rng,
    ).images[0]

    # Snap back to the exact requested dimensions
    if synth.size != (width, height):
        synth = synth.resize((width, height), Image.LANCZOS)
    return synth


# %% ── 5. Dataset generation pipeline ────────────────────────────────────────

def generate_synthetic_dataset(
    coco_dict: dict,
    output_dir: str = "synthetic_dataset",
    n_samples: int = 10,
    method: str = "controlnet",
    reference_image_path: Optional[str] = None,
    # ── ControlNet params ──
    controlnet_conditioning_scale: float = 1.0,
    sd_model_id: str = "runwayml/stable-diffusion-v1-5",
    controlnet_model_id: str = "lllyasviel/control_v11p_sd15_scribble",
    # ── img2img params ──
    img2img_model_id: str = "stabilityai/stable-diffusion-2-1",
    img2img_strength: float = 0.55,
    # ── Shared params ──
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    prompt: Optional[str] = None,
    negative_prompt: Optional[str] = None,
    device: str = "cuda",
) -> dict:
    """
    Generate a synthetic dataset of blot images with bands placed at the
    locations defined by the COCO annotations.

    Every sample shares the SAME layout (the annotations of the first image
    entry in ``coco_dict``); only the appearance varies via a fresh random
    diffusion seed per sample. Use ``generate_varied_dataset`` if you want
    per-sample layouts.

    Parameters
    ----------
    coco_dict : dict
        COCO-format annotation dict. Annotations define WHERE bands appear.
    output_dir : str
        Output directory for images/ and annotations.json.
    n_samples : int
        Number of synthetic images to generate.
    method : str
        'controlnet' (preferred); any other value falls back to img2img.
    reference_image_path : str or None
        Optional reference image for style. Used by img2img to sample
        band intensities for the template.
    controlnet_conditioning_scale : float
        ControlNet spatial strictness (0–2). Higher = bands more precisely
        at the annotated positions.
    sd_model_id, controlnet_model_id : str
        Hugging Face model ids for the ControlNet path.
    img2img_model_id : str
        Hugging Face model id for the img2img fallback path.
    img2img_strength : float
        img2img denoising strength (see generate_with_img2img).
    guidance_scale : float
        Prompt adherence.
    num_inference_steps : int
        Denoising steps.
    prompt, negative_prompt : str or None
        Override the built-in defaults when given.
    device : str
        'cuda' or 'cpu'.

    Returns
    -------
    dict – COCO annotation dict for the synthetic dataset (also written to
    <output_dir>/annotations.json).
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    (out / "images").mkdir(exist_ok=True)

    # The layout comes from the FIRST image entry only; collect the
    # annotations that reference it.
    img_info = coco_dict["images"][0]
    W, H = img_info["width"], img_info["height"]
    src_annotations = [a for a in coco_dict["annotations"]
                       if a["image_id"] == img_info["id"]]

    ref_img = None
    if reference_image_path:
        ref_img = Image.open(reference_image_path)

    default_prompt = (
        "a high quality grayscale western blot gel electrophoresis image, "
        "dark protein bands on light background, sharp bands, "
        "scientific laboratory photograph, high contrast"
    )
    default_neg = (
        "color, text, labels, annotations, bounding boxes, cartoon, "
        "drawing, blurry, low quality, watermark"
    )
    prompt = prompt or default_prompt
    negative_prompt = negative_prompt or default_neg

    # Output COCO skeleton: categories copied, images/annotations filled below.
    synth_coco = {
        "info": {**coco_dict["info"], "description": "Synthetic " + coco_dict["info"]["description"]},
        "images": [],
        "annotations": [],
        "categories": copy.deepcopy(coco_dict["categories"]),
    }

    # Save the control/template image for inspection
    if method == "controlnet":
        ctrl = annotations_to_scribble_map(src_annotations, W, H)
        ctrl.save(out / "control_image.png")
        print(f"  Control image saved to {out / 'control_image.png'}")
    else:
        # NOTE(review): this template uses fresh random state, so it may
        # differ from the templates rendered inside generate_with_img2img
        # for each sample — it is for visual inspection only.
        tmpl = annotations_to_template(src_annotations, W, H, reference_image=ref_img)
        tmpl.save(out / "template_image.png")
        print(f"  Template image saved to {out / 'template_image.png'}")

    # Annotation ids are renumbered sequentially across all synthetic images.
    ann_id_counter = 1

    for i in range(n_samples):
        # Fresh diffusion seed per sample: appearance varies, layout fixed.
        seed = random.randint(0, 2**32 - 1)
        print(f"\n{'='*60}")
        print(f"  Sample {i+1}/{n_samples}  seed={seed}  method={method}")
        print(f"{'='*60}")

        if method == "controlnet":
            synth_img = generate_with_controlnet(
                annotations=src_annotations,
                width=W,
                height=H,
                prompt=prompt,
                negative_prompt=negative_prompt,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                seed=seed,
                device=device,
                sd_model_id=sd_model_id,
                controlnet_model_id=controlnet_model_id,
            )
        else:
            synth_img = generate_with_img2img(
                annotations=src_annotations,
                width=W,
                height=H,
                reference_image=ref_img,
                prompt=prompt,
                negative_prompt=negative_prompt,
                strength=img2img_strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                seed=seed,
                device=device,
                model_id=img2img_model_id,
            )

        # ── Save ──
        fname = f"synth_{i:04d}.png"
        synth_img.save(out / "images" / fname)

        # Image ids are offset by 1000 to avoid clashing with source ids.
        img_id = i + 1000
        synth_coco["images"].append({
            "id": img_id,
            "file_name": fname,
            "width": W,
            "height": H,
        })

        # Annotations are the same layout we requested; add the COCO fields
        # (area, iscrowd) detectors typically expect.
        for ann in src_annotations:
            new_ann = copy.deepcopy(ann)
            new_ann["id"] = ann_id_counter
            new_ann["image_id"] = img_id
            bx, by, bw, bh = new_ann["bbox"]
            new_ann["area"] = round(bw * bh, 2)
            new_ann["iscrowd"] = 0
            synth_coco["annotations"].append(new_ann)
            ann_id_counter += 1

        print(f"  ✓ {fname}  ({len(src_annotations)} bands)")

    ann_path = out / "annotations.json"
    with open(ann_path, "w") as f:
        json.dump(synth_coco, f, indent=2)
    print(f"\n✅ Done: {n_samples} images → {out}")
    print(f"   Annotations: {ann_path}")
    return synth_coco


# %% ── 6. Variation: perturb annotations to create new layouts ────────────────

def perturb_annotations(
    annotations: list[dict],
    image_width: int,
    image_height: int,
    position_jitter: float = 0.1,
    size_jitter: float = 0.15,
    drop_prob: float = 0.1,
    add_prob: float = 0.05,
) -> list[dict]:
    """
    Create a variation of the annotation layout by randomly:
    - Shifting band positions (position_jitter as fraction of bbox size)
    - Scaling band sizes
    - Dropping some bands
    - Duplicating some bands nearby

    This lets you generate images with DIFFERENT layouts from a single
    set of seed annotations. The input list is not modified.

    Parameters
    ----------
    annotations : list[dict]
        Seed COCO annotations ('id', 'bbox' and optional 'score' are read).
    image_width, image_height : int
        Bounds the jittered boxes are clamped into.
    position_jitter, size_jitter : float
        Std-dev (as a fraction of bbox size) of position/size noise.
    drop_prob, add_prob : float
        Per-band probabilities of dropping it / duplicating it nearby.

    Returns
    -------
    list[dict] – new annotation dicts (deep copies of the inputs).
    """
    if not annotations:
        # Guard: max() below would raise ValueError on an empty sequence.
        return []

    new_anns = []
    next_id = max(a["id"] for a in annotations) + 1

    for ann in annotations:
        # Random drop
        if random.random() < drop_prob:
            continue

        x, y, w, h = ann["bbox"]

        # Position jitter
        dx = random.gauss(0, position_jitter * w)
        dy = random.gauss(0, position_jitter * h)

        # Size jitter
        sw = random.uniform(1 - size_jitter, 1 + size_jitter)
        sh = random.uniform(1 - size_jitter, 1 + size_jitter)

        # Score jitter, clamped to [0.3, 1.0]
        new_score = min(1.0, max(0.3, ann.get("score", 1.0) + random.gauss(0, 0.1)))

        # Keep the jittered box inside the image
        nx = max(0, min(image_width - w * sw, x + dx))
        ny = max(0, min(image_height - h * sh, y + dy))

        new_ann = copy.deepcopy(ann)
        new_ann["bbox"] = [round(nx, 1), round(ny, 1),
                           round(w * sw, 1), round(h * sh, 1)]
        new_ann["score"] = round(new_score, 2)
        new_anns.append(new_ann)

        # Random duplication nearby (same lane, typically above/below)
        if random.random() < add_prob:
            dup = copy.deepcopy(new_ann)
            dup["id"] = next_id
            next_id += 1
            dup_x = nx + random.gauss(0, w * 0.3)
            dup_y = ny + random.gauss(0, h * 2)
            dup_w = round(w * sw * random.uniform(0.8, 1.0), 1)
            dup_h = round(h * sh * random.uniform(0.8, 1.0), 1)
            # Clamp duplicates into the image on ALL sides, consistent with
            # the primary boxes (previously only clamped at 0, so duplicates
            # could spill past the right/bottom edge).
            dup["bbox"] = [round(max(0, min(image_width - dup_w, dup_x)), 1),
                           round(max(0, min(image_height - dup_h, dup_y)), 1),
                           dup_w, dup_h]
            dup["score"] = round(random.uniform(0.4, 0.8), 2)
            new_anns.append(dup)

    return new_anns


def generate_varied_dataset(
    coco_dict: dict,
    output_dir: str = "synthetic_varied",
    n_samples: int = 10,
    perturb: bool = True,
    **kwargs,
) -> dict:
    """
    Like generate_synthetic_dataset, but optionally perturbs the annotation
    layout for each sample so you get diverse band arrangements.

    Pass perturb=True to randomise layouts, or perturb=False to keep the
    exact same layout (same as generate_synthetic_dataset).

    Parameters
    ----------
    coco_dict : dict
        COCO-format annotation dict; the first image entry defines the
        seed layout.
    output_dir : str
        Output directory for images/ and annotations.json.
    n_samples : int
        Number of synthetic images to generate.
    perturb : bool
        Whether to jitter/drop/duplicate boxes per sample
        (see perturb_annotations).
    **kwargs
        Forwarded generation options: method, device, reference_image_path,
        prompt, negative_prompt, plus the per-method parameters accepted by
        generate_with_controlnet / generate_with_img2img.

    Returns
    -------
    dict – COCO annotation dict for the generated dataset (also written to
    <output_dir>/annotations.json).
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    (out / "images").mkdir(exist_ok=True)

    # Seed layout comes from the FIRST image entry only.
    img_info = coco_dict["images"][0]
    W, H = img_info["width"], img_info["height"]
    src_annotations = [a for a in coco_dict["annotations"]
                       if a["image_id"] == img_info["id"]]

    # Pull shared options out of kwargs (defaults mirror
    # generate_synthetic_dataset).
    method = kwargs.get("method", "controlnet")
    device = kwargs.get("device", "cuda")
    ref_path = kwargs.get("reference_image_path")
    ref_img = Image.open(ref_path) if ref_path else None

    prompt = kwargs.get("prompt") or (
        "a high quality grayscale western blot gel electrophoresis image, "
        "dark protein bands on light background, sharp bands, "
        "scientific laboratory photograph, high contrast"
    )
    negative_prompt = kwargs.get("negative_prompt") or (
        "color, text, labels, annotations, bounding boxes, cartoon, "
        "drawing, blurry, low quality, watermark"
    )

    # Output COCO skeleton; per-sample images/annotations appended below.
    synth_coco = {
        "info": {**coco_dict["info"], "description": "Synthetic Varied " + coco_dict["info"]["description"]},
        "images": [],
        "annotations": [],
        "categories": copy.deepcopy(coco_dict["categories"]),
    }
    ann_id_counter = 1  # annotation ids renumbered across all samples

    for i in range(n_samples):
        # Fresh diffusion seed per sample.
        seed = random.randint(0, 2**32 - 1)

        # Perturb layout for this sample
        if perturb:
            sample_anns = perturb_annotations(src_annotations, W, H)
        else:
            sample_anns = copy.deepcopy(src_annotations)

        print(f"\n[{i+1}/{n_samples}] seed={seed}  bands={len(sample_anns)}")

        if method == "controlnet":
            synth_img = generate_with_controlnet(
                annotations=sample_anns, width=W, height=H,
                prompt=prompt, negative_prompt=negative_prompt,
                controlnet_conditioning_scale=kwargs.get("controlnet_conditioning_scale", 1.0),
                guidance_scale=kwargs.get("guidance_scale", 7.5),
                num_inference_steps=kwargs.get("num_inference_steps", 50),
                seed=seed, device=device,
                sd_model_id=kwargs.get("sd_model_id", "runwayml/stable-diffusion-v1-5"),
                controlnet_model_id=kwargs.get("controlnet_model_id", "lllyasviel/control_v11p_sd15_scribble"),
            )
        else:
            # Any non-'controlnet' method falls back to img2img.
            synth_img = generate_with_img2img(
                annotations=sample_anns, width=W, height=H,
                reference_image=ref_img, prompt=prompt,
                negative_prompt=negative_prompt,
                strength=kwargs.get("img2img_strength", 0.55),
                guidance_scale=kwargs.get("guidance_scale", 7.5),
                num_inference_steps=kwargs.get("num_inference_steps", 50),
                seed=seed, device=device,
                model_id=kwargs.get("img2img_model_id", "stabilityai/stable-diffusion-2-1"),
            )

        fname = f"synth_{i:04d}.png"
        synth_img.save(out / "images" / fname)

        # Image ids offset by 2000: distinct from both the source ids and the
        # 1000-offset ids used by generate_synthetic_dataset.
        img_id = i + 2000
        synth_coco["images"].append({
            "id": img_id, "file_name": fname, "width": W, "height": H,
        })

        # Record this sample's (possibly perturbed) layout with the COCO
        # fields (area, iscrowd) detectors typically expect.
        for ann in sample_anns:
            new_ann = copy.deepcopy(ann)
            new_ann["id"] = ann_id_counter
            new_ann["image_id"] = img_id
            bw, bh = new_ann["bbox"][2], new_ann["bbox"][3]
            new_ann["area"] = round(bw * bh, 2)
            new_ann["iscrowd"] = 0
            synth_coco["annotations"].append(new_ann)
            ann_id_counter += 1

    ann_path = out / "annotations.json"
    with open(ann_path, "w") as f:
        json.dump(synth_coco, f, indent=2)
    print(f"\n✅ {n_samples} varied images → {out}")
    return synth_coco


# %% ── 7. Visualization ──────────────────────────────────────────────────────

def show_pipeline(
    annotations: list[dict],
    width: int,
    height: int,
    reference_image: Optional[Image.Image] = None,
):
    """Visualize the three conditioning images that drive generation,
    side by side: scribble map, intensity control, and rendered template."""
    import matplotlib.pyplot as plt

    panels = [
        (annotations_to_scribble_map(annotations, width, height),
         "Scribble Map\n(ControlNet input)"),
        (annotations_to_control_image(annotations, width, height),
         "Intensity Control\n(band strength)"),
        (annotations_to_template(annotations, width, height, reference_image),
         "Rendered Template\n(img2img input)"),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(18, 7))
    for ax, (panel, title) in zip(axes, panels):
        ax.imshow(panel, cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    plt.tight_layout()
    plt.show()


def show_comparison(synth_dir: str, n: int = 4, reference_path: Optional[str] = None):
    """Display generated samples with their annotations.

    Parameters
    ----------
    synth_dir : str
        Directory produced by generate_synthetic_dataset (must contain
        images/ and annotations.json).
    n : int
        Number of generated samples to display. The figure is capped at
        5 columns, so fewer samples may be shown.
    reference_path : str or None
        Optional reference image shown in the first column.
    """
    import matplotlib.pyplot as plt

    synth_path = Path(synth_dir)
    with open(synth_path / "annotations.json") as f:
        synth_coco = json.load(f)

    offset = 1 if reference_path else 0
    cols = min(n + offset, 5)
    # Cap displayed samples to the available axes — previously n + offset
    # could exceed the 5-column cap and index past the axes array.
    n_show = cols - offset

    fig, axes = plt.subplots(1, cols, figsize=(5 * cols, 6))
    if cols == 1:
        axes = [axes]

    if reference_path:
        ref = Image.open(reference_path)
        axes[0].imshow(ref, cmap="gray")
        axes[0].set_title("Reference", fontweight="bold")
        axes[0].axis("off")

    for i, img_info in enumerate(synth_coco["images"][:n_show]):
        img = Image.open(synth_path / "images" / img_info["file_name"])
        anns = [a for a in synth_coco["annotations"]
                if a["image_id"] == img_info["id"]]
        axes[i + offset].imshow(visualize_annotations(img, anns))
        axes[i + offset].set_title(f"Generated #{i+1}")
        axes[i + offset].axis("off")

    plt.tight_layout()
    plt.show()
