"""
Inference Script for Western Blot Band Detection

Supports both normal and low-memory modes.

Usage:
    # Normal inference
    python inference.py --model ./output/best_model.pt --image blot.png

    # Low memory mode (for limited GPU/RAM)
    python inference.py --model ./output/best_model.pt --image blot.png --low-memory

    # Batch inference on folder
    python inference.py --model ./output/best_model.pt --input-dir ./images --output-dir ./results

    # Adjust confidence threshold
    python inference.py --model ./output/best_model.pt --image blot.png --conf 0.3
"""

import argparse
import json
from pathlib import Path
from typing import List, Tuple, Optional
import gc

import torch
from torchvision import transforms as T
from PIL import Image
import numpy as np
import cv2


class WesternBlotDetector:
    """Western Blot band detector with normal and low-memory modes.

    Builds one of three torchvision detection architectures configured for
    two classes (background + band), loads a saved ``state_dict`` into it and
    runs single-image inference.

    Low-memory mode changes the following:
      * ``device='auto'`` resolves to CPU instead of CUDA;
      * weights are staged on the CPU and freed eagerly after loading;
      * when explicitly run on ``'cuda'``, model and inputs use float16;
      * inputs are downscaled so their longer side is at most 640 px unless
        an explicit ``max_size`` is passed to ``predict``;
      * garbage collection / CUDA cache clearing runs between images.
    """

    # Longest-side cap applied to inputs in low-memory mode when the caller
    # does not pass ``max_size`` explicitly.
    _LOW_MEMORY_MAX_SIZE = 640

    def __init__(
        self,
        model_path: str,
        model_type: Optional[str] = None,
        device: str = 'auto',
        low_memory: bool = False
    ):
        """
        Initialize detector.

        Args:
            model_path: Path to .pt file containing a ``state_dict``
            model_type: Model architecture ('ssd', 'fasterrcnn_mobilenet', 'fasterrcnn').
                       Read from ``model_info.json`` next to the model file if
                       omitted; falls back to 'fasterrcnn_mobilenet'.
            device: 'auto', 'cuda', or 'cpu'
            low_memory: Enable low-memory optimizations

        Raises:
            ValueError: if the resolved model type is unknown.
        """
        self.low_memory = low_memory
        self.model_path = model_path

        # Setup device
        if device == 'auto':
            if low_memory:
                # Prefer CPU for low memory mode to avoid GPU memory issues
                self.device = 'cpu'
            else:
                self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        # Auto-detect model type from the training run's metadata file.
        info_path = Path(model_path).parent / 'model_info.json'
        if model_type is None and info_path.exists():
            with open(info_path, encoding='utf-8') as f:
                info = json.load(f)
            model_type = info.get('model_type', 'fasterrcnn_mobilenet')
        elif model_type is None:
            model_type = 'fasterrcnn_mobilenet'

        self.model_type = model_type

        print("Detector Configuration:")
        print(f"  Model: {model_type}")
        print(f"  Device: {self.device}")
        print(f"  Low memory mode: {low_memory}")

        # Load model
        self.model = self._load_model()

    @property
    def _use_half(self) -> bool:
        # float16 is only used when low-memory mode runs on a CUDA device.
        return self.low_memory and self.device == 'cuda'

    def _load_model(self):
        """Build the architecture, load weights and move it to ``self.device``.

        Both modes load with ``weights_only=True``, which restricts unpickling
        to tensors and plain containers so a checkpoint file cannot run
        arbitrary code.
        """
        model = self._get_model_architecture(self.model_type)

        if self.low_memory:
            # Stage weights on the CPU and release them as soon as they have
            # been copied into the model, keeping peak memory low.
            state_dict = torch.load(
                self.model_path,
                map_location='cpu',
                weights_only=True
            )
            model.load_state_dict(state_dict)
            del state_dict
            gc.collect()
        else:
            state_dict = torch.load(
                self.model_path,
                map_location=self.device,
                weights_only=True
            )
            model.load_state_dict(state_dict)

        model.to(self.device)
        model.eval()

        # Use half precision in low memory mode on GPU
        if self._use_half:
            model.half()

        return model

    def _get_model_architecture(self, model_type: str):
        """Return an untrained model whose head has ``num_classes`` outputs.

        The head replacement must match how the checkpoint was built during
        training, otherwise ``load_state_dict`` fails.

        Raises:
            ValueError: for an unknown ``model_type``.
        """
        num_classes = 2  # background + band

        if model_type == 'ssd':
            from torchvision.models.detection import ssdlite320_mobilenet_v3_large
            from torchvision.models.detection.ssd import SSDClassificationHead
            model = ssdlite320_mobilenet_v3_large(weights=None)
            # NOTE(review): values presumably mirror the training script's
            # head construction for this backbone — keep them in sync.
            in_channels = [672, 480, 512, 256, 256, 128]
            num_anchors = [6, 6, 6, 6, 6, 6]
            model.head.classification_head = SSDClassificationHead(
                in_channels, num_anchors, num_classes
            )

        elif model_type == 'fasterrcnn_mobilenet':
            from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn
            from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
            model = fasterrcnn_mobilenet_v3_large_fpn(weights=None)
            in_features = model.roi_heads.box_predictor.cls_score.in_features
            model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

        elif model_type == 'fasterrcnn':
            from torchvision.models.detection import fasterrcnn_resnet50_fpn
            from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
            model = fasterrcnn_resnet50_fpn(weights=None)
            in_features = model.roi_heads.box_predictor.cls_score.in_features
            model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

        else:
            raise ValueError(f"Unknown model type: {model_type}")

        return model

    def _preprocess_image(
        self,
        image: Image.Image,
        max_size: Optional[int] = None
    ) -> Tuple[torch.Tensor, float]:
        """
        Preprocess image for inference.

        Downscales (never upscales) so that neither side exceeds ``max_size``,
        then converts to a CHW float tensor in [0, 1].

        Returns:
            tensor: Preprocessed image tensor
            scale: Scale factor applied (for rescaling boxes); 1.0 if unchanged
        """
        orig_w, orig_h = image.size
        scale = 1.0

        # Resize if needed (for low memory mode)
        if max_size is not None:
            scale = min(max_size / orig_w, max_size / orig_h, 1.0)
            if scale < 1:
                new_w, new_h = int(orig_w * scale), int(orig_h * scale)
                image = image.resize((new_w, new_h), Image.BILINEAR)

        tensor = T.ToTensor()(image)

        # Half precision for low memory GPU mode
        if self._use_half:
            tensor = tensor.half()

        return tensor, scale

    def predict(
        self,
        image_path: str,
        conf_threshold: float = 0.5,
        max_size: Optional[int] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Run inference on a single image.

        Args:
            image_path: Path to input image (converted to RGB)
            conf_threshold: Keep detections whose score is strictly greater
                than this value
            max_size: Max image dimension (None for original size; defaults
                to 640 in low-memory mode)

        Returns:
            boxes: float32 array of [x1, y1, x2, y2] boxes, one row per
                detection, in original image coordinates
            scores: float32 confidence scores, aligned with ``boxes``
            image: Original RGB image as numpy array
        """
        # Load the image and close the file immediately; convert() returns an
        # independent, fully loaded copy. Without the context manager the
        # handle stays open until garbage collection (leaks in batch runs).
        with Image.open(image_path) as src:
            image = src.convert('RGB')

        if self.low_memory and max_size is None:
            max_size = self._LOW_MEMORY_MAX_SIZE

        tensor, scale = self._preprocess_image(image, max_size)
        batch = tensor.unsqueeze(0).to(self.device)

        with torch.no_grad():
            if self._use_half:
                with torch.autocast(device_type='cuda'):
                    predictions = self.model(batch)[0]
            else:
                predictions = self.model(batch)[0]

        # Up-cast to float32 before filtering/rescaling: float16 coordinates
        # lose whole pixels once scaled back up to a large original image.
        scores = predictions['scores'].float().cpu().numpy()
        all_boxes = predictions['boxes'].float().cpu().numpy()

        mask = scores > conf_threshold
        boxes = all_boxes[mask]
        scores = scores[mask]

        # Rescale boxes to original image size
        if scale < 1:
            boxes = boxes / scale

        # Clear cache in low memory mode
        if self._use_half:
            torch.cuda.empty_cache()

        return boxes, scores, np.array(image)

    def predict_batch(
        self,
        image_paths: List[str],
        conf_threshold: float = 0.5,
        max_size: Optional[int] = None
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """
        Run inference on multiple images, one at a time.

        Images are processed sequentially through ``predict`` (no tensor
        batching). In low-memory mode, garbage collection runs and the CUDA
        cache is emptied after every image to keep peak memory down.

        Returns:
            A list of ``(boxes, scores)`` tuples in the order of ``image_paths``.
        """
        results = []

        for path in image_paths:
            boxes, scores, _ = self.predict(path, conf_threshold, max_size)
            results.append((boxes, scores))

            if self.low_memory:
                gc.collect()
                if self.device == 'cuda':
                    torch.cuda.empty_cache()

        return results


def visualize_predictions(
    image: np.ndarray,
    boxes: np.ndarray,
    scores: np.ndarray,
    output_path: Optional[str] = None,
    show: bool = False
) -> np.ndarray:
    """Draw bounding boxes and confidence labels on a BGR copy of *image*.

    Args:
        image: RGB image array (as returned by ``WesternBlotDetector.predict``).
        boxes: [x1, y1, x2, y2] boxes, one row per detection.
        scores: Confidence score for each box.
        output_path: If given, write the annotated image to this path.
        show: If True, display the result in a window and wait for a key.

    Returns:
        The annotated image in BGR channel order.

    Raises:
        OSError: if *output_path* is given and OpenCV reports a failed write.
    """
    img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale, thickness = 0.5, 1

    for box, score in zip(boxes, scores):
        # Plain Python ints (truncated like astype(int)) are accepted by every
        # OpenCV version as point coordinates.
        x1, y1, x2, y2 = (int(v) for v in box)

        # Color based on confidence (BGR): green > 0.7, yellow > 0.5, else orange
        if score > 0.7:
            color = (0, 255, 0)  # Green
        elif score > 0.5:
            color = (0, 255, 255)  # Yellow
        else:
            color = (0, 165, 255)  # Orange

        cv2.rectangle(img_bgr, (x1, y1), (x2, y2), color, 2)

        # Label with score: above the box normally, but inside it when the
        # box touches the top edge so the label is not clipped off-image.
        label = f'{score:.2f}'
        (label_w, label_h), _ = cv2.getTextSize(label, font, font_scale, thickness)
        top = y1 - label_h - 5
        if top < 0:
            top = y1
        bottom = top + label_h + 5
        cv2.rectangle(img_bgr, (x1, top), (x1 + label_w, bottom), color, -1)
        cv2.putText(img_bgr, label, (x1, bottom - 3),
                    font, font_scale, (0, 0, 0), thickness)

    if output_path:
        # cv2.imwrite signals failure (e.g. missing directory) by returning
        # False rather than raising; surface it so callers don't report a
        # file that was never written.
        if not cv2.imwrite(str(output_path), img_bgr):
            raise OSError(f"Failed to write image: {output_path}")

    if show:
        cv2.imshow('Predictions', img_bgr)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    return img_bgr


def save_predictions_json(
    boxes: np.ndarray,
    scores: np.ndarray,
    output_path: str,
    image_path: str = None
):
    """Write detections to *output_path* as indented JSON.

    The document has three keys: ``image`` (the source path as a string, or
    null when *image_path* is falsy), ``num_detections`` (``len(boxes)``) and
    ``detections``. Each detection holds the box coordinates as a list, the
    score as a float, and the fixed class label ``'band'``.
    """
    detections = []
    for coords, confidence in zip(boxes, scores):
        detections.append({
            'box': coords.tolist(),
            'score': float(confidence),
            'class': 'band',
        })

    payload = {
        'image': None if not image_path else str(image_path),
        'num_detections': len(boxes),
        'detections': detections,
    }

    with open(output_path, 'w') as handle:
        json.dump(payload, handle, indent=2)


def _run_single_image(detector, args):
    """Run inference on ``args.image`` and write the requested outputs."""
    print(f"\nProcessing: {args.image}")

    boxes, scores, image = detector.predict(
        args.image,
        conf_threshold=args.conf,
        max_size=args.max_size
    )

    print(f"Detected {len(boxes)} bands (conf > {args.conf})")

    # Visualize
    output_path = args.output or 'prediction.png'
    if not args.no_save:
        visualize_predictions(image, boxes, scores, output_path)
        print(f"Saved: {output_path}")

    if args.show:
        visualize_predictions(image, boxes, scores, show=True)

    if args.save_json:
        json_path = Path(output_path).with_suffix('.json')
        save_predictions_json(boxes, scores, json_path, args.image)
        print(f"Saved: {json_path}")


def _run_batch(detector, args):
    """Run inference on every image file in ``args.input_dir``."""
    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir or './predictions')
    output_dir.mkdir(parents=True, exist_ok=True)

    # Find all images; sort so runs are reproducible (iterdir() order is
    # filesystem-dependent).
    image_extensions = {'.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp'}
    image_paths = sorted(
        p for p in input_dir.iterdir()
        if p.is_file() and p.suffix.lower() in image_extensions
    )

    print(f"\nProcessing {len(image_paths)} images from {input_dir}")

    total_detections = 0
    failed = []
    for i, img_path in enumerate(image_paths, start=1):
        print(f"  [{i}/{len(image_paths)}] {img_path.name}...", end=' ')

        try:
            boxes, scores, image = detector.predict(
                str(img_path),
                conf_threshold=args.conf,
                max_size=args.max_size
            )
        except OSError as err:
            # Unreadable/corrupt image (PIL raises OSError subclasses): report
            # it and continue with the rest of the batch.
            print(f"FAILED ({err})")
            failed.append(img_path)
            continue

        total_detections += len(boxes)
        print(f"{len(boxes)} bands")

        # Save visualization (honours --no-save like single-image mode)
        if not args.no_save:
            output_path = output_dir / f"{img_path.stem}_pred.png"
            visualize_predictions(image, boxes, scores, str(output_path))

        # Save JSON if requested
        if args.save_json:
            json_path = output_dir / f"{img_path.stem}_pred.json"
            save_predictions_json(boxes, scores, str(json_path), str(img_path))

    processed = len(image_paths) - len(failed)
    print("\nBatch complete!")
    print(f"  Total images: {len(image_paths)}")
    if failed:
        print(f"  Failed images: {len(failed)}")
    print(f"  Total detections: {total_detections}")
    # Guard against an empty directory (or all failures): avoid ZeroDivisionError.
    if processed:
        print(f"  Average per image: {total_detections / processed:.1f}")
    print(f"  Output directory: {output_dir}")


def main():
    """Parse command-line arguments and run single-image and/or batch inference."""
    parser = argparse.ArgumentParser(
        description='Western Blot Band Detection Inference',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Single image inference
  python inference.py --model output/best_model.pt --image western.png
  
  # Low memory mode (for limited GPU/RAM)
  python inference.py --model output/best_model.pt --image western.png --low-memory
  
  # Batch inference on folder
  python inference.py --model output/best_model.pt --input-dir ./images --output-dir ./results
  
  # Save results as JSON
  python inference.py --model output/best_model.pt --image western.png --save-json
        """
    )

    # Model arguments
    parser.add_argument('--model', type=str, required=True,
                        help='Path to model .pt file')
    parser.add_argument('--type', type=str, default=None,
                        choices=['ssd', 'fasterrcnn_mobilenet', 'fasterrcnn'],
                        help='Model type (auto-detected if not specified)')

    # Input arguments
    parser.add_argument('--image', type=str, default=None,
                        help='Single input image path')
    parser.add_argument('--input-dir', type=str, default=None,
                        help='Directory of input images for batch processing')

    # Output arguments
    parser.add_argument('--output', type=str, default=None,
                        help='Output image path (default: prediction.png)')
    parser.add_argument('--output-dir', type=str, default=None,
                        help='Output directory for batch processing (default: ./predictions)')
    parser.add_argument('--save-json', action='store_true',
                        help='Save predictions as JSON')

    # Inference arguments
    parser.add_argument('--conf', type=float, default=0.5,
                        help='Confidence threshold (default: 0.5)')
    parser.add_argument('--device', type=str, default='auto',
                        choices=['auto', 'cuda', 'cpu'],
                        help='Device for inference')

    # Memory arguments
    parser.add_argument('--low-memory', action='store_true',
                        help='Enable low memory mode')
    parser.add_argument('--max-size', type=int, default=None,
                        help='Max image dimension (auto in low-memory mode)')

    # Display
    parser.add_argument('--show', action='store_true',
                        help='Display result in window')
    parser.add_argument('--no-save', action='store_true',
                        help='Do not save output image')

    args = parser.parse_args()

    # Validate inputs up front so bad arguments fail before the model loads.
    if args.image is None and args.input_dir is None:
        parser.error('Must specify --image or --input-dir')
    if args.image is not None and not Path(args.image).is_file():
        parser.error(f'--image file not found: {args.image}')
    if args.input_dir is not None and not Path(args.input_dir).is_dir():
        parser.error(f'--input-dir is not a directory: {args.input_dir}')
    if args.max_size is not None and args.max_size <= 0:
        parser.error('--max-size must be a positive integer')
    if not 0.0 <= args.conf <= 1.0:
        parser.error('--conf must be between 0 and 1')

    # Initialize detector
    detector = WesternBlotDetector(
        model_path=args.model,
        model_type=args.type,
        device=args.device,
        low_memory=args.low_memory
    )

    if args.image:
        _run_single_image(detector, args)

    if args.input_dir:
        _run_batch(detector, args)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
