"""
Simple Auto-Label Western Blot Images (No GPU Required)

Uses classical computer vision to detect bands.
Good as a starting point or fallback when GPU is not available.

Supports two annotation modes:
  - BAND mode (default): One annotation per individual band
  - ROW mode: One annotation per horizontal row (groups bands at same Y position)

Installation:
    pip install opencv-python numpy

Usage:
    # Annotate individual bands (default)
    python auto_label_simple.py --image western.png
    
    # Annotate by ROWS
    python auto_label_simple.py --image western.png --mode row
    
    # Folder of images
    python auto_label_simple.py --input-dir ./images --output-dir ./labels
    
    # Adjust sensitivity
    python auto_label_simple.py --image western.png --min-area 500 --threshold 0.3
    
    # Row mode with custom tolerance
    python auto_label_simple.py --image western.png --mode row --row-tolerance 40
"""

import argparse
import json
from pathlib import Path
from typing import List, Tuple
import cv2
import numpy as np


def group_detections_into_rows(
    detections: List[dict],
    y_tolerance: int = 30
) -> List[dict]:
    """
    Merge individual band detections into horizontal rows.

    Bands whose vertical centers fall within ``y_tolerance`` pixels of the
    current row's (running) mean center are assigned to that row; each row
    is then collapsed into a single bounding box that covers its members.

    Args:
        detections: Individual band detections, each with a 'box'
            ([x1, y1, x2, y2]) and a 'score' entry.
        y_tolerance: Maximum Y distance (pixels) for two bands to share a row.

    Returns:
        One merged detection dict per row, each with 'box', 'score'
        (mean of member scores), 'num_bands', 'area' and 'aspect_ratio'.
    """
    if not detections:
        return []

    def _center_y(det: dict) -> float:
        # Vertical midpoint of the detection's bounding box.
        return (det['box'][1] + det['box'][3]) / 2

    ordered = sorted(detections, key=_center_y)

    # Sweep top-to-bottom, growing the last row while new detections stay
    # close to its running mean center.
    grouped: List[List[dict]] = [[ordered[0]]]
    anchor = _center_y(ordered[0])

    for det in ordered[1:]:
        center = _center_y(det)
        if abs(center - anchor) <= y_tolerance:
            grouped[-1].append(det)
            # Re-anchor on the mean of all members collected so far.
            anchor = np.mean([_center_y(d) for d in grouped[-1]])
        else:
            grouped.append([det])
            anchor = center

    # Collapse each group into one enclosing box.
    merged: List[dict] = []
    for members in grouped:
        left = min(d['box'][0] for d in members)
        top = min(d['box'][1] for d in members)
        right = max(d['box'][2] for d in members)
        bottom = max(d['box'][3] for d in members)
        width = right - left
        height = bottom - top

        merged.append({
            'box': [left, top, right, bottom],
            'score': float(np.mean([d['score'] for d in members])),
            'num_bands': len(members),
            'area': width * height,
            'aspect_ratio': width / height if height > 0 else 0
        })

    return merged


def detect_bands_adaptive(
    image: np.ndarray,
    min_area: int = 300,
    max_area: int = 50000,
    min_aspect_ratio: float = 1.5,  # Width/height - bands are wide
    max_aspect_ratio: float = 20.0,
    threshold_factor: float = 0.3
) -> List[dict]:
    """
    Detect bands via adaptive (locally varying) thresholding.

    Pipeline: grayscale -> polarity normalization (bands become bright) ->
    Gaussian blur -> Gaussian adaptive threshold -> morphological cleanup ->
    contour extraction -> area/aspect-ratio filtering.

    Args:
        image: Input image (BGR or grayscale)
        min_area: Minimum band area
        max_area: Maximum band area
        min_aspect_ratio: Minimum width/height ratio
        max_aspect_ratio: Maximum width/height ratio
        threshold_factor: Lower = more sensitive

    Returns:
        Detections sorted top-to-bottom then left-to-right; each dict
        holds 'box' ([x1, y1, x2, y2]), 'score', 'area', 'aspect_ratio'.
    """
    # Grayscale conversion (3-channel input assumed BGR, per cv2 convention).
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image.copy()

    # Normalize polarity so the bands are always the bright features:
    # a bright overall mean implies dark bands on a light background.
    signal = 255 - gray if np.mean(gray) > 127 else gray

    # Suppress pixel noise before thresholding.
    smoothed = cv2.GaussianBlur(signal, (5, 5), 0)

    # Local Gaussian-weighted threshold; negative C raises the bar above
    # the neighborhood mean, scaled by the sensitivity factor.
    binary = cv2.adaptiveThreshold(
        smoothed, 255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        blockSize=51,
        C=-int(threshold_factor * 30)
    )

    # Wide kernel bridges gaps along a band; small kernel removes speckle.
    horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 3))
    speckle_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    mask = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, horizontal_kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, speckle_kernel)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    results: List[dict] = []
    for contour in contours:
        blob_area = cv2.contourArea(contour)
        if not (min_area <= blob_area <= max_area):
            continue

        x, y, w, h = cv2.boundingRect(contour)
        ratio = w / h if h > 0 else 0
        if not (min_aspect_ratio <= ratio <= max_aspect_ratio):
            continue

        # Heuristic confidence: elongated-and-large blobs score higher,
        # capped below 1.0 to signal these are not model probabilities.
        score = min(0.99, min(1.0, ratio / 5.0) * min(1.0, blob_area / 1000))

        results.append({
            'box': [x, y, x + w, y + h],
            'score': score,
            'area': blob_area,
            'aspect_ratio': ratio
        })

    # Reading order: top-to-bottom, then left-to-right.
    results.sort(key=lambda d: (d['box'][1], d['box'][0]))

    return results


def detect_bands_otsu(
    image: np.ndarray,
    min_area: int = 300,
    max_area: int = 50000,
    min_aspect_ratio: float = 1.5,
    max_aspect_ratio: float = 20.0
) -> List[dict]:
    """
    Detect bands with a single global Otsu threshold.

    Works well when there's good contrast between bands and background;
    for uneven illumination prefer :func:`detect_bands_adaptive`.
    Returns detections sorted top-to-bottom then left-to-right.
    """
    # Grayscale conversion (3-channel input assumed BGR).
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image.copy()

    # Normalize polarity so bands are the bright features.
    signal = 255 - gray if np.mean(gray) > 127 else gray

    # Denoise, then let Otsu pick the global threshold automatically.
    smoothed = cv2.GaussianBlur(signal, (5, 5), 0)
    _, binary = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Wide closing kernel bridges small gaps along horizontal bands.
    closing_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (10, 3))
    mask = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, closing_kernel)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    results: List[dict] = []
    for contour in contours:
        blob_area = cv2.contourArea(contour)
        if not (min_area <= blob_area <= max_area):
            continue

        x, y, w, h = cv2.boundingRect(contour)
        ratio = w / h if h > 0 else 0
        if not (min_aspect_ratio <= ratio <= max_aspect_ratio):
            continue

        # Heuristic confidence from elongation and size, capped below 1.0.
        score = min(0.99, ratio / 5.0 * blob_area / 2000)

        results.append({
            'box': [x, y, x + w, y + h],
            'score': score,
            'area': blob_area,
            'aspect_ratio': ratio
        })

    # Reading order: top-to-bottom, then left-to-right.
    results.sort(key=lambda d: (d['box'][1], d['box'][0]))
    return results


def auto_label_image(
    image_path: str,
    output_dir: str = None,
    method: str = 'adaptive',
    min_area: int = 300,
    max_area: int = 50000,
    min_aspect_ratio: float = 1.5,
    threshold_factor: float = 0.3,
    save_visualization: bool = True,
    mode: str = 'band',
    row_tolerance: int = 30
) -> dict:
    """
    Auto-label a Western blot image.

    Runs classical band detection, optionally groups the bands into
    horizontal rows, and writes up to three artifacts into ``output_dir``:
    a YOLO-format ``.txt`` label file, a ``.json`` annotation file, and
    (optionally) a ``_auto_labeled.png`` visualization. In row mode each
    output filename gets a ``_rows`` suffix.

    Args:
        image_path: Path to input image
        output_dir: Output directory (defaults to '<image dir>/labels')
        method: 'adaptive' or 'otsu'
        min_area: Minimum band area in pixels
        max_area: Maximum band area
        min_aspect_ratio: Minimum width/height ratio
        threshold_factor: Sensitivity (lower = more detections); only
            used by the 'adaptive' method
        save_visualization: Save image with boxes
        mode: 'band' for individual bands, 'row' for grouped rows
        row_tolerance: Y-distance tolerance for row grouping (pixels)

    Returns:
        Dict with 'image_path', 'num_detections', 'detections', 'mode'

    Raises:
        ValueError: If the image cannot be read by cv2.
    """
    image_path = Path(image_path)
    # Default output location sits next to the source image.
    output_dir = Path(output_dir) if output_dir else image_path.parent / 'labels'
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Load image (cv2.imread returns None instead of raising on failure)
    image = cv2.imread(str(image_path))
    if image is None:
        raise ValueError(f"Could not load image: {image_path}")
    
    h, w = image.shape[:2]
    
    print(f"Processing: {image_path.name} ({w}x{h}) [mode: {mode}]")
    
    # Detect individual bands first; row grouping (if any) happens afterwards
    if method == 'adaptive':
        detections = detect_bands_adaptive(
            image,
            min_area=min_area,
            max_area=max_area,
            min_aspect_ratio=min_aspect_ratio,
            threshold_factor=threshold_factor
        )
    else:
        # NOTE: threshold_factor is irrelevant for the Otsu method
        detections = detect_bands_otsu(
            image,
            min_area=min_area,
            max_area=max_area,
            min_aspect_ratio=min_aspect_ratio
        )
    
    print(f"  Found {len(detections)} individual bands")
    
    # Group into rows if requested
    if mode == 'row' and detections:
        detections = group_detections_into_rows(detections, y_tolerance=row_tolerance)
        print(f"  Grouped into {len(detections)} rows")
    
    # Save YOLO format: "class x_center y_center width height", all
    # coordinates normalized to [0, 1] by image size; class is always 0.
    yolo_lines = []
    for det in detections:
        x1, y1, x2, y2 = det['box']
        x_center = ((x1 + x2) / 2) / w
        y_center = ((y1 + y2) / 2) / h
        box_w = (x2 - x1) / w
        box_h = (y2 - y1) / h
        yolo_lines.append(f"0 {x_center:.6f} {y_center:.6f} {box_w:.6f} {box_h:.6f}")
    
    # Row-mode outputs get a distinct suffix so they never overwrite
    # band-mode labels for the same image.
    suffix = '_rows' if mode == 'row' else ''
    yolo_path = output_dir / f"{image_path.stem}{suffix}.txt"
    with open(yolo_path, 'w') as f:
        f.write('\n'.join(yolo_lines))
    print(f"  Saved: {yolo_path}")
    
    # Save JSON format (absolute pixel coordinates, unlike the YOLO file)
    category_name = 'row' if mode == 'row' else 'band'
    json_data = {
        'image': image_path.name,
        'width': w,
        'height': h,
        'mode': mode,
        'annotations': [
            {
                'bbox': det['box'],
                'score': det['score'],
                'category': category_name,
                # num_bands only exists on row detections; bands count as 1
                'num_bands': det.get('num_bands', 1)
            }
            for det in detections
        ]
    }
    json_path = output_dir / f"{image_path.stem}{suffix}.json"
    with open(json_path, 'w') as f:
        json.dump(json_data, f, indent=2)
    print(f"  Saved: {json_path}")
    
    # Save visualization (boxes drawn on a copy of the input image)
    if save_visualization:
        vis = image.copy()
        
        if mode == 'row':
            # Different colors for each row (BGR tuples, cycled)
            colors = [
                (0, 0, 255), (0, 255, 0), (255, 0, 0), (0, 255, 255),
                (255, 0, 255), (255, 255, 0), (128, 0, 255), (255, 128, 0)
            ]
            for i, det in enumerate(detections):
                x1, y1, x2, y2 = [int(v) for v in det['box']]
                color = colors[i % len(colors)]
                cv2.rectangle(vis, (x1, y1), (x2, y2), color, 3)
                label = f"Row {i+1} ({det.get('num_bands', 1)})"
                cv2.putText(vis, label, (x1, y1-5),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        else:
            for det in detections:
                x1, y1, x2, y2 = [int(v) for v in det['box']]
                score = det['score']
                # Green for confident detections, yellow for weaker ones
                color = (0, 255, 0) if score > 0.5 else (0, 255, 255)
                cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
                cv2.putText(vis, f"{score:.2f}", (x1, y1-5),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1)
        
        vis_path = output_dir / f"{image_path.stem}{suffix}_auto_labeled.png"
        cv2.imwrite(str(vis_path), vis)
        print(f"  Saved: {vis_path}")
    
    return {
        'image_path': str(image_path),
        'num_detections': len(detections),
        'detections': detections,
        'mode': mode
    }


def auto_label_directory(
    input_dir: str,
    output_dir: str = None,
    method: str = 'adaptive',
    min_area: int = 300,
    threshold_factor: float = 0.3,
    save_visualization: bool = True,
    mode: str = 'band',
    row_tolerance: int = 30,
    max_area: int = 50000
):
    """
    Auto-label all images in a directory.

    Args:
        input_dir: Directory scanned (non-recursively) for images with
            extensions .png/.jpg/.jpeg/.tif/.tiff/.bmp
        output_dir: Where labels are written (defaults to '<input_dir>/labels')
        method: 'adaptive' or 'otsu'
        min_area: Minimum band area in pixels
        threshold_factor: Detection sensitivity (lower = more detections)
        save_visualization: Also save an annotated PNG per image
        mode: 'band' for individual bands, 'row' for grouped rows
        row_tolerance: Y-distance tolerance for row grouping (pixels)
        max_area: Maximum band area in pixels. New keyword added at the
            end of the signature (backward-compatible); previously this
            function always fell back to auto_label_image's default of
            50000, silently ignoring the --max-area CLI flag.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir) if output_dir else input_dir / 'labels'
    output_dir.mkdir(parents=True, exist_ok=True)
    
    image_extensions = {'.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp'}
    image_paths = [p for p in input_dir.iterdir() if p.suffix.lower() in image_extensions]
    
    print(f"Found {len(image_paths)} images")
    print(f"Mode: {mode}" + (f" (tolerance: {row_tolerance}px)" if mode == 'row' else ''))
    print()
    
    total = 0
    for i, img_path in enumerate(image_paths, 1):
        print(f"[{i}/{len(image_paths)}] ", end='')
        result = auto_label_image(
            str(img_path),
            output_dir=str(output_dir),
            method=method,
            min_area=min_area,
            max_area=max_area,  # BUGFIX: previously not forwarded
            threshold_factor=threshold_factor,
            save_visualization=save_visualization,
            mode=mode,
            row_tolerance=row_tolerance
        )
        total += result['num_detections']
        print()
    
    label_type = "rows" if mode == 'row' else "bands"
    print(f"\nComplete!")
    print(f"  Images: {len(image_paths)}")
    print(f"  Total {label_type}: {total}")
    print(f"  Output: {output_dir}")
    print(f"\nNote: These are rough detections. Review and correct in:")
    print(f"  - CVAT (cvat.ai)")
    print(f"  - Roboflow (roboflow.com)")
    print(f"  - Label Studio (labelstud.io)")


def main():
    """Parse CLI arguments and dispatch to single-image or directory mode."""
    parser = argparse.ArgumentParser(
        description='Simple auto-labeling for Western blots (no GPU required)',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Annotate individual bands (default)
  python auto_label_simple.py --image blot.png
  
  # Annotate by ROWS (groups bands horizontally)
  python auto_label_simple.py --image blot.png --mode row
  
  # Folder of images
  python auto_label_simple.py --input-dir ./images --output-dir ./labels
  
  # Row mode with custom tolerance
  python auto_label_simple.py --image blot.png --mode row --row-tolerance 50
  
  # Adjust detection sensitivity
  python auto_label_simple.py --image blot.png --threshold 0.2 --min-area 200

Annotation Modes:
  - band (default): One box per individual band
  - row: One box per horizontal row (groups bands at similar Y positions)

This uses classical computer vision (no AI/GPU needed).
For better results, use auto_label.py with Grounding DINO.
        """
    )
    
    # Exactly one input source must be supplied.
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument('--image', type=str, help='Single image')
    input_group.add_argument('--input-dir', type=str, help='Directory of images')
    
    parser.add_argument('--output-dir', type=str, default=None)
    parser.add_argument('--method', type=str, default='adaptive',
                        choices=['adaptive', 'otsu'])
    parser.add_argument('--min-area', type=int, default=300,
                        help='Minimum band area (pixels)')
    parser.add_argument('--max-area', type=int, default=50000,
                        help='Maximum band area (pixels)')
    parser.add_argument('--threshold', type=float, default=0.3,
                        help='Detection sensitivity (lower = more detections)')
    parser.add_argument('--no-vis', action='store_true')
    
    # Annotation mode
    parser.add_argument('--mode', type=str, default='band', choices=['band', 'row'],
                        help='Annotation mode: "band" for individual bands, "row" for grouped rows')
    parser.add_argument('--row-tolerance', type=int, default=30,
                        help='Y-distance tolerance for row grouping in pixels (default: 30)')
    
    args = parser.parse_args()
    
    if args.image:
        auto_label_image(
            args.image,
            output_dir=args.output_dir or './labels',
            method=args.method,
            min_area=args.min_area,
            max_area=args.max_area,
            threshold_factor=args.threshold,
            save_visualization=not args.no_vis,
            mode=args.mode,
            row_tolerance=args.row_tolerance
        )
    else:
        auto_label_directory(
            args.input_dir,
            output_dir=args.output_dir,
            method=args.method,
            min_area=args.min_area,
            max_area=args.max_area,  # BUGFIX: --max-area now honored in directory mode
            threshold_factor=args.threshold,
            save_visualization=not args.no_vis,
            mode=args.mode,
            row_tolerance=args.row_tolerance
        )


if __name__ == "__main__":
    main()
