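"""
Visualization script for previewing generated Western Blot images.

Generates a grid of synthetic blots with randomly varied parameters, draws
each annotated band as a green box, and saves the grid as a single image.

Usage:
    python visualize_samples.py --num 6 --output preview.png [--seed 42]
"""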

import argparse
from typing import Optional

import cv2
import matplotlib.pyplot as plt
import numpy as np

from western_blot_generator import generate_western_blot_with_annotations


def _draw_annotations(img, annotations):
    """Return an RGB copy of ``img`` with every annotation drawn as a green box.

    Each annotation supplies ``x_center``, ``y_center``, ``width`` and
    ``height`` as fractions of the image width/height; they are scaled to
    pixels here. ``img`` itself is never modified.
    """
    height, width = img.shape[:2]
    # cv2.cvtColor(GRAY2RGB) rejects multi-channel input, so only convert
    # single-channel images; anything else is copied unchanged.
    if img.ndim == 2 or img.shape[2] == 1:
        canvas = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    else:
        canvas = img.copy()

    for ann in annotations:
        # Work in floats and round once at the end; truncating the centre and
        # the size separately can shift a box by up to a pixel.
        cx = ann['x_center'] * width
        cy = ann['y_center'] * height
        half_w = ann['width'] * width / 2
        half_h = ann['height'] * height / 2
        top_left = (int(round(cx - half_w)), int(round(cy - half_h)))
        bottom_right = (int(round(cx + half_w)), int(round(cy + half_h)))
        cv2.rectangle(canvas, top_left, bottom_right, (0, 255, 0), 2)

    return canvas


def visualize_samples(num_samples: int = 6, output_path: str = 'preview.png',
                      seed: Optional[int] = None):
    """Generate sample Western Blots and save a grid preview with annotations.

    Each sample comes from ``generate_western_blot_with_annotations`` with
    randomly varied parameters, and its annotations are drawn as green boxes.
    The grid has at most three columns.

    Args:
        num_samples: Number of blots to generate; must be at least 1.
        output_path: File the preview is written to via matplotlib's savefig.
        seed: If given, seeds NumPy's global RNG before the parameter draws so
            the run is reproducible.

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """
    # Zero would divide by zero in the row computation below and negative
    # values would request an invalid subplot grid.
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    if seed is not None:
        np.random.seed(seed)

    cols = min(3, num_samples)
    rows = (num_samples + cols - 1) // cols  # ceiling division

    # squeeze=False always yields a 2-D array of Axes, so 1x1 and single-row
    # grids need no special casing.
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)
    axes = axes.ravel()

    for i, ax in enumerate(axes[:num_samples]):
        # Generate with varied parameters. Keep the argument order: the random
        # draws happen left to right, which makes a given seed reproducible.
        img, annotations = generate_western_blot_with_annotations(
            num_lanes=np.random.randint(4, 10),
            num_protein_rows=(2, np.random.randint(4, 7)),
            bands_per_row_probability=np.random.uniform(0.6, 0.9),
            noise_level=np.random.uniform(10, 20),
            background_intensity=np.random.randint(180, 220),
            row_tilt_range=(-18, 18),
            skew_range=(-8, 8),
            irregular_loading=True,
            add_faint_bands=True,
            faint_bands_per_lane=(0, 3),
            include_ladder=np.random.random() > 0.2,
            ladder_position='auto',
            add_border=True,
            border_color='auto'  # Random white or black border
        )

        ax.imshow(_draw_annotations(img, annotations))
        ax.set_title(f'Sample {i+1}: {len(annotations)} bands', fontsize=10)
        ax.axis('off')

    # Hide the unused cells in the last grid row.
    for ax in axes[num_samples:]:
        ax.axis('off')

    fig.suptitle('Synthetic Western Blots with Row-Aligned Bands\n' +
                 'Green boxes = annotated bands (one per blob) | Faint bands ignored',
                 fontsize=11, y=1.02)
    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    # Close this specific figure rather than whichever one is "current".
    plt.close(fig)

    print(f"Saved preview to: {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Visualize sample Western Blots')
    parser.add_argument('--num', type=int, default=6, help='Number of samples')
    parser.add_argument('--output', type=str, default='preview.png', help='Output file')
    parser.add_argument('--seed', type=int, default=None, help='Random seed')

    args = parser.parse_args()
    # A non-positive count cannot be laid out as a grid; report it as a usage
    # error (exit status 2) instead of letting the plotting code fail with a
    # traceback.
    if args.num < 1:
        parser.error(f"--num must be at least 1, got {args.num}")

    visualize_samples(num_samples=args.num, output_path=args.output, seed=args.seed)
