"""
YOLO Dataset Generator for Western Blot Images

Usage:
    python generate_dataset.py --output ./dataset --train 500 --val 100 --test 100

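This will create a YOLO-format dataset (data.yaml plus train/val/test splits,
each with images/ and labels/ directories) ready for training, along with a
COCO-format annotations.json in every split directory.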
"""

import argparse
import cv2
import numpy as np
from pathlib import Path
import yaml
import json
from western_blot_generator import generate_western_blot_with_annotations, annotation_to_yolo


def _yolo_to_coco_bbox(ann, img_width, img_height):
    """
    Convert one center-based annotation into a COCO ``[x, y, w, h]`` pixel box.

    The annotation's 'x_center', 'y_center', 'width' and 'height' values are
    multiplied by the image size, so they are assumed to be normalized to
    [0, 1] (YOLO convention) -- TODO confirm against western_blot_generator.

    Values are cast to built-in floats so that json.dump can serialize the
    result even if the generator hands back NumPy scalars.
    """
    w = float(ann['width']) * img_width
    h = float(ann['height']) * img_height
    x = float(ann['x_center']) * img_width - w / 2
    y = float(ann['y_center']) * img_height - h / 2
    return [x, y, w, h]


def generate_yolo_dataset(
    output_dir: str,
    num_train: int = 500,
    num_val: int = 100,
    num_test: int = 100,
    width: int = 800,
    height: int = 500,
    **kwargs
) -> Path:
    """
    Generate a complete YOLO-format dataset.
    Also generates COCO-format annotations for PyTorch/HuggingFace compatibility.

    Layout written under ``output_dir``::

        data.yaml
        <split>/images/western_blot_00000.png
        <split>/labels/western_blot_00000.txt
        <split>/annotations.json

    where ``<split>`` is each of 'train', 'val' and 'test'.

    Args:
        output_dir: Root directory of the dataset (created if missing).
        num_train: Number of images to generate for the 'train' split.
        num_val: Number of images to generate for the 'val' split.
        num_test: Number of images to generate for the 'test' split.
        width: Image width passed to the generator.
        height: Image height passed to the generator.
        **kwargs: Extra generator options; these override the randomized
            per-image defaults built inside this function.

    Returns:
        The dataset root as a ``Path``.

    Raises:
        OSError: If an image could not be written to disk.
    """
    output_path = Path(output_dir)

    # Create directory structure
    splits = ['train', 'val', 'test']
    counts = [num_train, num_val, num_test]

    for split in splits:
        (output_path / split / 'images').mkdir(parents=True, exist_ok=True)
        (output_path / split / 'labels').mkdir(parents=True, exist_ok=True)

    # Generate images for each split
    total_bands = 0

    for split, count in zip(splits, counts):
        print(f"Generating {count} {split} images...")

        # COCO format annotations (image and annotation ids restart per split,
        # since every split gets its own annotations.json)
        coco_annotations = {
            'images': [],
            'annotations': [],
            'categories': [{'id': 0, 'name': 'band'}]
        }
        annotation_id = 0

        for i in range(count):
            # Vary parameters for diversity
            gen_kwargs = {
                'width': width,
                'height': height,
                'num_lanes': np.random.randint(3, 12),
                'num_protein_rows': (2, np.random.randint(4, 8)),
                'bands_per_row_probability': np.random.uniform(0.5, 0.9),
                'noise_level': np.random.uniform(8, 25),
                'background_intensity': np.random.randint(160, 230),
                'row_tilt_range': (-20, 20),
                'skew_range': (-10, 10),
                'irregular_loading': True,
                'add_faint_bands': True,
                'faint_bands_per_lane': (0, 3),
                'include_ladder': np.random.random() > 0.2,
                'ladder_position': 'auto',
                'add_border': True,
                'border_color': 'auto',  # Random white or black border
            }
            # Caller-supplied options win over the randomized defaults.
            gen_kwargs.update(kwargs)

            img, annotations = generate_western_blot_with_annotations(**gen_kwargs)
            total_bands += len(annotations)

            # Use the size of the image actually produced rather than the
            # requested size, so the COCO metadata and pixel boxes match the
            # file on disk even if the generator changes the dimensions.
            img_height, img_width = (int(v) for v in img.shape[:2])

            # Save image
            img_filename = f"western_blot_{i:05d}.png"
            img_path = output_path / split / 'images' / img_filename
            # cv2.imwrite reports failure through its return value; without
            # this check a label file would be written for a missing image.
            if not cv2.imwrite(str(img_path), img):
                raise OSError(f"Failed to write image: {img_path}")

            # Save YOLO annotations
            label_filename = f"western_blot_{i:05d}.txt"
            label_path = output_path / split / 'labels' / label_filename

            with open(label_path, 'w') as f:
                f.writelines(annotation_to_yolo(ann) + '\n' for ann in annotations)

            # Add to COCO format
            coco_annotations['images'].append({
                'id': i,
                'file_name': img_filename,
                'width': img_width,
                'height': img_height
            })

            for ann in annotations:
                # Convert YOLO to COCO format (x, y, w, h in pixels)
                bbox = _yolo_to_coco_bbox(ann, img_width, img_height)

                coco_annotations['annotations'].append({
                    'id': annotation_id,
                    'image_id': i,
                    'category_id': 0,
                    'bbox': bbox,
                    'area': bbox[2] * bbox[3],
                    'iscrowd': 0
                })
                annotation_id += 1

            if (i + 1) % 50 == 0:
                print(f"  {i + 1}/{count} complete...")

        # Save COCO annotations
        coco_path = output_path / split / 'annotations.json'
        with open(coco_path, 'w') as f:
            json.dump(coco_annotations, f)

    # Create data.yaml for YOLO
    data_yaml = {
        'path': str(output_path.absolute()),
        'train': 'train/images',
        'val': 'val/images',
        'test': 'test/images',
        'nc': 1,
        'names': ['band']
    }

    with open(output_path / 'data.yaml', 'w') as f:
        yaml.dump(data_yaml, f, default_flow_style=False)

    total_images = num_train + num_val + num_test
    # Guard the average: all three counts may legitimately be zero.
    average_bands = total_bands / total_images if total_images > 0 else 0.0
    print(f"\n{'='*60}")
    print(f"Dataset created at: {output_path.absolute()}")
    print(f"{'='*60}")
    print(f"  Total images: {total_images}")
    print(f"  Total annotated bands: {total_bands}")
    print(f"  Average bands per image: {average_bands:.1f}")
    print(f"\n  - Train: {num_train} images")
    print(f"  - Val:   {num_val} images")
    print(f"  - Test:  {num_test} images")
    print(f"\nFormats generated:")
    print(f"  - YOLO: data.yaml + labels/*.txt")
    print(f"  - COCO: annotations.json (for PyTorch/HuggingFace)")

    return output_path


if __name__ == "__main__":
    # Command-line entry point: parse the dataset options and hand them
    # straight to generate_yolo_dataset().
    cli = argparse.ArgumentParser(description='Generate Western Blot YOLO dataset')
    cli.add_argument(
        '--output',
        type=str,
        default='./western_blot_dataset',
        help='Output directory for dataset',
    )

    # All remaining options are plain integers: (flag, default, help text).
    integer_options = (
        ('--train', 500, 'Number of training images'),
        ('--val', 100, 'Number of validation images'),
        ('--test', 100, 'Number of test images'),
        ('--width', 800, 'Image width'),
        ('--height', 500, 'Image height'),
    )
    for flag, default_value, help_text in integer_options:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)

    options = cli.parse_args()

    generate_yolo_dataset(
        output_dir=options.output,
        num_train=options.train,
        num_val=options.val,
        num_test=options.test,
        width=options.width,
        height=options.height,
    )
