"""
Western Blot Synthetic Image Generator with YOLO Annotations

Features:
- Single class annotation (band)
- Row-aligned bands across lanes (realistic protein migration)
- Row tilt (slight upward or downward slope)
- Ladder/marker can be first OR last lane
- Merged overlapping annotations (one annotation per blob)
- Irregular loading (varying intensity across lanes)
- Faint/thin bands, which are annotated like every other band

Author: Generated for YOLO training
"""

import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Optional, List, Tuple, Dict
import yaml


def _band_annotation(x, y, band_width, band_height, img_width, img_height,
                     width_pad, height_pad):
    """Build one normalized YOLO-style annotation dict for a band drawn at (x, y)."""
    return {
        'class_id': 0,
        'x_center': x / img_width,
        'y_center': y / img_height,
        'width': (band_width * width_pad) / img_width,
        'height': (band_height * height_pad) / img_height,
    }


def generate_western_blot_with_annotations(
    width: int = 800,
    height: int = 500,
    num_lanes: int = 4,
    include_ladder: bool = True,
    ladder_position: str = 'auto',
    num_protein_rows: Tuple[int, int] = (2, 6),
    bands_per_row_probability: float = 0.7,
    background_intensity: int = 200,
    noise_level: float = 15,
    band_blur: Tuple[int, int] = (3, 15),  # (kernel width, kernel height) for cv2.GaussianBlur
    row_tilt_range: Tuple[float, float] = (-15, 15),
    skew_range: Tuple[float, float] = (-8, 8),
    irregular_loading: bool = True,
    add_faint_bands: bool = True,
    faint_bands_per_lane: Tuple[int, int] = (0, 2),
    add_border: bool = True,
    border_color: str = 'auto',  # 'white', 'black', or 'auto' (random)
    seed: Optional[int] = None
) -> Tuple[np.ndarray, List[Dict]]:
    """
    Generate a synthetic Western blot image with row-aligned bands.

    Bands are dark Gaussian blobs on a gray, noisy background. Every band that
    is drawn (sample, ladder and faint) gets an annotation.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        num_lanes: Number of sample lanes (the ladder lane is added on top).
        include_ladder: Whether to add a ladder/marker lane.
        ladder_position: 'first', 'last' or 'auto' (random choice between the
            two). Any other value places the ladder in the last lane.
        num_protein_rows: Inclusive (min, max) range for the number of band rows.
        bands_per_row_probability: Chance that a sample lane gets a band in a
            given row.
        background_intensity: Base gray level of the gel background.
        noise_level: Standard deviation of the additive Gaussian noise.
        band_blur: (kernel width, kernel height) handed to cv2.GaussianBlur;
            even values are bumped to the next odd number.
            NOTE(review): with the default (3, 15) the kernel is taller than it
            is wide, i.e. the blur is stronger vertically, although earlier
            comments described it as "more horizontal" - confirm the intent.
        row_tilt_range: Range of the per-row tilt; a row's vertical offset is
            tilt * (lane_x - width / 2) / width pixels.
        skew_range: Range (degrees) of the global rotation; angles with a
            magnitude of 0.5 or less are skipped.
        irregular_loading: Draw a per-lane intensity multiplier from
            U(0.6, 1.2) instead of using 1.0 for every lane.
        add_faint_bands: Add extra low-intensity bands to the sample lanes.
        faint_bands_per_lane: Inclusive (min, max) range of faint bands per lane.
        add_border: Apply the random white/black border overlay.
        border_color: 'white', 'black' or 'auto' (random).
        seed: When given, seeds NumPy's global random state.

    Returns:
        A tuple (image, annotations): image is a uint8 grayscale array of shape
        (height, width); annotations is a list of dicts with the keys
        'class_id', 'x_center', 'y_center', 'width' and 'height', the last four
        normalized by the image size.
    """
    if seed is not None:
        np.random.seed(seed)

    # Determine border color (white or black)
    if border_color == 'auto':
        use_black_border = np.random.random() > 0.5
    else:
        use_black_border = (border_color == 'black')

    # Gray gel background with subtle vertical and horizontal gradients
    img = np.ones((height, width), dtype=np.float32) * background_intensity
    img += np.linspace(0, 15, height).reshape(-1, 1)
    img += np.sin(np.linspace(0, np.pi, width)) * 8

    annotations = []

    # Determine ladder position
    if ladder_position == 'auto':
        ladder_position = 'first' if np.random.random() > 0.5 else 'last'

    # Lane layout. A lone lane has no neighbour to measure spacing against, so
    # it is centred and sized from the usable width; a spacing of 0 would
    # produce zero-width bands and an empty randint range further down.
    total_lanes = num_lanes + (1 if include_ladder else 0)
    lane_margin = width * 0.08
    usable_width = width - 2 * lane_margin
    if total_lanes > 1:
        lane_spacing = usable_width / (total_lanes - 1)
        lane_x_positions = [int(lane_margin + i * lane_spacing) for i in range(total_lanes)]
    else:
        lane_spacing = usable_width
        lane_x_positions = [width // 2] * total_lanes

    # Width range shared by ladder and faint bands; clamped so randint always
    # receives low < high, even for very narrow lanes.
    narrow_width_low = max(1, int(lane_spacing * 0.5))
    narrow_width_high = max(narrow_width_low + 1, int(lane_spacing * 0.8))

    # Assign which lanes are sample vs ladder
    if include_ladder:
        if ladder_position == 'first':
            ladder_lane_idx = 0
            sample_lane_indices = list(range(1, total_lanes))
        else:
            ladder_lane_idx = total_lanes - 1
            sample_lane_indices = list(range(0, total_lanes - 1))
    else:
        ladder_lane_idx = None
        sample_lane_indices = list(range(total_lanes))

    # Lane loading multipliers
    if irregular_loading:
        lane_multipliers = np.random.uniform(0.6, 1.2, total_lanes)
    else:
        lane_multipliers = np.ones(total_lanes)

    # Protein rows: sorted relative heights plus one tilt per row
    num_rows = np.random.randint(num_protein_rows[0], num_protein_rows[1] + 1)
    row_y_positions = np.sort(np.random.uniform(0.15, 0.85, num_rows))
    row_tilts = np.random.uniform(row_tilt_range[0], row_tilt_range[1], num_rows)

    # Sample bands: each row shares a base height/width/intensity that is then
    # jittered per lane.
    for rel_y, tilt in zip(row_y_positions, row_tilts):
        base_y = int(height * rel_y)
        base_band_height = np.random.randint(8, 18)
        base_band_width_factor = np.random.uniform(0.7, 0.95)
        base_intensity = np.random.uniform(0.5, 1.0)

        for lane_idx in sample_lane_indices:
            # Skip if random chance says no band here
            if np.random.random() > bands_per_row_probability:
                continue

            lane_x = lane_x_positions[lane_idx]
            lane_mult = lane_multipliers[lane_idx]

            tilt_offset = tilt * (lane_x - width / 2) / width
            y_pos = int(base_y + tilt_offset)

            intensity = np.clip(base_intensity * lane_mult * np.random.uniform(0.85, 1.15), 0, 1)
            band_height = int(base_band_height * np.random.uniform(0.8, 1.2))
            # Never let a band collapse to zero width (would divide by zero in _add_band)
            band_width = max(1, int(lane_spacing * base_band_width_factor * np.random.uniform(0.85, 1.1)))

            img = _add_band(img, lane_x, y_pos, band_width, band_height, intensity)
            annotations.append(
                _band_annotation(lane_x, y_pos, band_width, band_height, width, height, 1.1, 2.8)
            )

    # Ladder bands: evenly spaced down the ladder lane
    if ladder_lane_idx is not None:
        ladder_x = lane_x_positions[ladder_lane_idx]
        ladder_mult = lane_multipliers[ladder_lane_idx]
        ladder_positions = np.linspace(0.08, 0.92, np.random.randint(8, 14))

        for rel_y in ladder_positions:
            y_pos = int(height * rel_y)
            intensity = np.random.uniform(0.5, 0.85) * ladder_mult
            band_height = np.random.randint(6, 14)
            band_width = np.random.randint(narrow_width_low, narrow_width_high)

            img = _add_band(img, ladder_x, y_pos, band_width, band_height, intensity)
            annotations.append(
                _band_annotation(ladder_x, y_pos, band_width, band_height, width, height, 1.2, 3)
            )

    # Additional faint bands at random heights in the sample lanes (annotated too)
    if add_faint_bands:
        for lane_idx in sample_lane_indices:
            lane_x = lane_x_positions[lane_idx]
            num_faint = np.random.randint(faint_bands_per_lane[0], faint_bands_per_lane[1] + 1)

            for _ in range(num_faint):
                y_pos = int(height * np.random.uniform(0.1, 0.9))

                # Faint bands - lower intensity
                intensity = np.random.uniform(0.15, 0.4)
                band_height = np.random.randint(5, 12)
                band_width = np.random.randint(narrow_width_low, narrow_width_high)

                img = _add_band(img, lane_x, y_pos, band_width, band_height, intensity)
                annotations.append(
                    _band_annotation(lane_x, y_pos, band_width, band_height, width, height, 1.2, 3)
                )

    # Blur; cv2.GaussianBlur needs odd kernel sizes given as (width, height)
    ksize_w = band_blur[0] if band_blur[0] % 2 == 1 else band_blur[0] + 1
    ksize_h = band_blur[1] if band_blur[1] % 2 == 1 else band_blur[1] + 1
    img = cv2.GaussianBlur(img, (ksize_w, ksize_h), 0)

    # Additive Gaussian noise
    noise = np.random.normal(0, noise_level, img.shape)
    img = img + noise

    # Speckle noise: rescale roughly 0.2% of the pixels by a random factor
    speckle_mask = np.random.random(img.shape) < 0.002
    img[speckle_mask] = img[speckle_mask] * np.random.uniform(0.5, 1.5, speckle_mask.sum())

    # Clip and convert
    img = np.clip(img, 0, 255).astype(np.uint8)

    # Global skew (rotation); revealed corners are filled with the border color
    skew_angle = np.random.uniform(skew_range[0], skew_range[1])
    if abs(skew_angle) > 0.5:
        img, annotations = _apply_rotation(img, skew_angle, annotations, use_black_border)

    # Random border overlay (white or black)
    if add_border:
        img, annotations = _add_border(img, annotations, use_black_border)

    return img, annotations


def _add_band(img, x, y, width, height, intensity):
    """Add a single dark band with Gaussian falloff on gray background.

    Args:
        img: 2-D grayscale array; it is not modified in place.
        x, y: Band center in pixel coordinates (column, row).
        width, height: Nominal band size in pixels; the Gaussian sigma along
            each axis is one third of the corresponding size.
        intensity: Darkness factor; the band center is darkened by
            ``intensity * 180`` gray levels.

    Returns:
        A new array with the band subtracted from ``img``.
    """
    img_h, img_w = img.shape
    y_coords, x_coords = np.ogrid[0:img_h, 0:img_w]

    # A zero-sized band would give sigma == 0 and fill the image with NaN/inf
    # through a division by zero, so sigma is floored at one third of a pixel.
    # Integer sizes >= 1 are unaffected by the floor.
    min_sigma = 1.0 / 3.0
    sigma_x = max(abs(width) / 3, min_sigma)
    sigma_y = max(abs(height) / 3, min_sigma)

    gaussian = np.exp(-((x_coords - x)**2 / (2 * sigma_x**2) +
                        (y_coords - y)**2 / (2 * sigma_y**2)))

    # Dark bands on gray background
    darkness = intensity * 180
    return img - gaussian * darkness


def _apply_rotation(img, angle, annotations, use_black_border=False):
    """Rotate the image about its center and move the annotations with it.

    Args:
        img: Image array; the first two dimensions are (height, width).
        angle: Rotation angle in degrees, passed to cv2.getRotationMatrix2D
            (positive values rotate counter-clockwise per the OpenCV docs).
        annotations: List of dicts with normalized 'x_center', 'y_center',
            'width' and 'height' entries.
        use_black_border: Fill the areas revealed by the rotation with black
            (0) instead of white (255).

    Returns:
        A tuple (rotated_image, adjusted_annotations). Annotations whose
        rotated center falls outside the image are dropped. The input dicts
        are not modified.
    """
    h, w = img.shape[:2]

    # Rotation matrix around center
    center = (w / 2, h / 2)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)

    # Border color (areas revealed by rotation)
    border_value = 0 if use_black_border else 255

    # Rotate image
    rotated = cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=border_value)

    # In image coordinates (y pointing down) the point mapping of the OpenCV
    # matrix equals a standard rotation by -angle.
    theta = np.radians(-angle)
    cos_a = np.cos(theta)
    sin_a = np.sin(theta)
    abs_cos = abs(cos_a)
    abs_sin = abs(sin_a)

    adjusted_annotations = []
    for ann in annotations:
        # Center in pixel coordinates relative to the rotation center
        x = ann['x_center'] * w - center[0]
        y = ann['y_center'] * h - center[1]

        # Rotate point
        new_x = x * cos_a - y * sin_a + center[0]
        new_y = x * sin_a + y * cos_a + center[1]

        # Keep only boxes whose center is still inside the image
        if 0 < new_x < w and 0 < new_y < h:
            box_w = ann['width'] * w
            box_h = ann['height'] * h

            new_ann = ann.copy()
            new_ann['x_center'] = new_x / w
            new_ann['y_center'] = new_y / h
            # A rotated rectangle needs a larger axis-aligned box to stay
            # enclosed; keeping the old size would clip the ends of wide bands.
            new_ann['width'] = (box_w * abs_cos + box_h * abs_sin) / w
            new_ann['height'] = (box_w * abs_sin + box_h * abs_cos) / h
            adjusted_annotations.append(new_ann)

    return rotated, adjusted_annotations


def _add_border(img, annotations, use_black_border=False):
    """Add random border/corner regions (white or black) to simulate real Western blot edges.

    One of four styles is drawn at random: a triangular corner cut, a straight
    strip along one side, an irregular polygon attached to one edge, or no
    border at all. The border region is blended in through a Gaussian-blurred
    mask, and annotations whose center lies in the border are removed.

    Args:
        img: Grayscale image; the first two dimensions are (height, width).
        annotations: List of dicts with normalized 'x_center' / 'y_center'.
        use_black_border: Paint the border black (0) instead of white (255).

    Returns:
        A tuple (image, annotations). For the 'none' style both inputs are
        returned unchanged; otherwise a new uint8 image and a filtered list.
    """
    h, w = img.shape[:2]

    # Border color
    border_color = 0 if use_black_border else 255

    # Randomly choose border style
    border_type = np.random.choice(['corner', 'side', 'irregular', 'none'], p=[0.3, 0.2, 0.3, 0.2])

    if border_type == 'none':
        return img, annotations

    # 1 = keep gel pixel, 0 = replace with border color
    mask = np.ones((h, w), dtype=np.float32)

    if border_type == 'corner':
        # Triangular corner cut, widest at the image edge and tapering inward
        corner = np.random.choice(['top_right', 'top_left', 'bottom_right', 'bottom_left'])
        corner_size_x = np.random.randint(w // 6, w // 3)
        corner_size_y = np.random.randint(h // 4, h // 2)

        # Vectorized form of the per-pixel test (rows: (h, 1), cols: (1, w))
        rows, cols = np.ogrid[0:h, 0:w]
        if corner in ('top_right', 'top_left'):
            in_rows = rows < corner_size_y
            offset = rows * corner_size_x / corner_size_y
        else:
            in_rows = rows > h - corner_size_y
            offset = (h - rows) * corner_size_x / corner_size_y

        if corner in ('top_right', 'bottom_right'):
            in_cols = cols > w - corner_size_x + offset
        else:
            in_cols = cols < corner_size_x - offset

        mask[in_rows & in_cols] = 0

    elif border_type == 'side':
        # Straight strip along one side, sized relative to the dimension it
        # cuts into (width for left/right, height for top/bottom).
        side = np.random.choice(['right', 'left', 'top', 'bottom'])
        extent = w if side in ('right', 'left') else h
        strip_size = np.random.randint(extent // 8, extent // 4)

        # Explicit start indices: a negative-slice form would blank the whole
        # mask when strip_size is 0.
        if side == 'right':
            mask[:, w - strip_size:] = 0
        elif side == 'left':
            mask[:, :strip_size] = 0
        elif side == 'top':
            mask[:strip_size, :] = 0
        elif side == 'bottom':
            mask[h - strip_size:, :] = 0

    elif border_type == 'irregular':
        # Irregular region: polygon through two image corners on one edge plus
        # random points within a third of the image from that edge
        num_points = np.random.randint(4, 8)
        edge = np.random.choice(['right', 'left', 'top', 'bottom'])
        points = []

        if edge == 'right':
            points.append((w, 0))
            points.append((w, h))
            for _ in range(num_points):
                points.append((w - np.random.randint(0, w // 3), np.random.randint(0, h)))
        elif edge == 'left':
            points.append((0, 0))
            points.append((0, h))
            for _ in range(num_points):
                points.append((np.random.randint(0, w // 3), np.random.randint(0, h)))
        elif edge == 'top':
            points.append((0, 0))
            points.append((w, 0))
            for _ in range(num_points):
                points.append((np.random.randint(0, w), np.random.randint(0, h // 3)))
        else:
            points.append((0, h))
            points.append((w, h))
            for _ in range(num_points):
                points.append((np.random.randint(0, w), h - np.random.randint(0, h // 3)))

        points = np.array(points, dtype=np.int32)
        cv2.fillPoly(mask, [points], 0)

    # Blur the mask for smooth transition
    mask = cv2.GaussianBlur(mask, (21, 21), 0)

    # Blend gel and border color through the soft mask
    img_float = img.astype(np.float32)
    img_float = img_float * mask + border_color * (1 - mask)
    img = np.clip(img_float, 0, 255).astype(np.uint8)

    # Filter out annotations whose center falls in the border region
    filtered_annotations = []
    for ann in annotations:
        x = int(ann['x_center'] * w)
        y = int(ann['y_center'] * h)
        if 0 <= x < w and 0 <= y < h and mask[y, x] > 0.5:
            filtered_annotations.append(ann)

    return img, filtered_annotations


def annotation_to_yolo(annotation: Dict) -> str:
    """Render one annotation dict as a YOLO label line.

    The line is ``class_id x_center y_center width height`` separated by single
    spaces, with the four box values formatted to six decimal places.
    """
    box_keys = ('x_center', 'y_center', 'width', 'height')
    fields = [format(annotation['class_id'])]
    for key in box_keys:
        fields.append(format(annotation[key], '.6f'))
    return ' '.join(fields)


def visualize_annotations(img: np.ndarray, annotations: List[Dict], figsize=(12, 8)):
    """Draw every annotation as a green rectangle and show the result in a figure.

    The grayscale input is converted to RGB, one 2-pixel-wide box is drawn per
    annotation (normalized values are scaled by the image size), and the
    overlay is placed in a new matplotlib figure with the axes hidden.

    Returns:
        The RGB image with the boxes drawn on it.
    """
    height, width = img.shape[:2]
    overlay = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    box_color = (0, 255, 0)

    for ann in annotations:
        # Center and half-extents in pixels (truncated exactly as whole pixels)
        cx = int(ann['x_center'] * width)
        cy = int(ann['y_center'] * height)
        half_w = int(ann['width'] * width) / 2
        half_h = int(ann['height'] * height) / 2

        top_left = (int(cx - half_w), int(cy - half_h))
        bottom_right = (int(cx + half_w), int(cy + half_h))
        cv2.rectangle(overlay, top_left, bottom_right, box_color, 2)

    plt.figure(figsize=figsize)
    plt.imshow(overlay)
    plt.title(f'Western Blot with {len(annotations)} annotated bands')
    plt.axis('off')
    return overlay
