"""
PyTorch Training Script for Western Blot Band Detection (Memory Optimized)

Models (from lightest to heaviest):
    - ssd: SSDLite MobileNetV3 (fastest, least memory)
    - fasterrcnn_mobilenet: Faster R-CNN MobileNetV3 (good balance)
    - fasterrcnn: Faster R-CNN ResNet50 (most accurate, most memory)

Usage:
    # Lightest model (recommended if running out of memory)
    python train.py --data ./dataset --model ssd --batch-size 4

    # Medium model
    python train.py --data ./dataset --model fasterrcnn_mobilenet --batch-size 2

    # If still OOM, reduce image size
    python train.py --data ./dataset --model ssd --batch-size 2 --image-size 480
"""

import argparse
import json
from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from PIL import Image
import numpy as np


class WesternBlotDataset(Dataset):
    """COCO-format detection dataset that downscales images to save memory.

    Expected layout:
        root_dir/<split>/annotations.json   (COCO-style dict)
        root_dir/<split>/images/<file_name>

    Every kept annotation is mapped to foreground label 1; torchvision
    detection models reserve label 0 for background.
    """

    def __init__(self, root_dir: str, split: str = 'train', transforms=None, max_size: int = 640):
        self.root_dir = Path(root_dir)
        self.split = split
        self.transforms = transforms
        self.max_size = max_size  # cap on the longest image side, in pixels

        ann_path = self.root_dir / split / 'annotations.json'
        with open(ann_path, 'r') as f:
            self.coco = json.load(f)

        self.images = self.coco['images']
        # Group annotations by image id once, so __getitem__ is O(1) per image.
        self.annotations_by_image: dict = {}
        for ann in self.coco['annotations']:
            self.annotations_by_image.setdefault(ann['image_id'], []).append(ann)

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int):
        """Return ``(image, target)`` where target is a torchvision-style dict."""
        img_info = self.images[idx]
        img_id = img_info['id']

        img_path = self.root_dir / self.split / 'images' / img_info['file_name']
        # Context manager releases the underlying file handle promptly;
        # convert() forces a full load, so the result is independent of the file.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        orig_w, orig_h = image.size

        # Downscale so the longest side is at most max_size; never upscale.
        scale = min(self.max_size / orig_w, self.max_size / orig_h, 1.0)
        if scale < 1:
            new_w, new_h = int(orig_w * scale), int(orig_h * scale)
            image = image.resize((new_w, new_h), Image.BILINEAR)

        anns = self.annotations_by_image.get(img_id, [])

        boxes, labels, areas = [], [], []
        for ann in anns:
            x, y, w, h = (v * scale for v in ann['bbox'])
            if w > 1 and h > 1:  # drop boxes that collapse after resizing
                boxes.append([x, y, x + w, y + h])  # COCO xywh -> xyxy
                labels.append(1)  # single foreground class (band)
                areas.append(w * h)

        if boxes:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
            areas = torch.as_tensor(areas, dtype=torch.float32)
        else:
            # Empty tensors keep downstream loss computation shape-safe.
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
            areas = torch.zeros((0,), dtype=torch.float32)

        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([img_id]),
            'area': areas,
            'iscrowd': torch.zeros((len(boxes),), dtype=torch.int64)
        }

        if self.transforms:
            image = self.transforms(image)

        return image, target


def get_transform(train: bool):
    """Build the image transform pipeline for a dataset split.

    Geometric augmentations (e.g. ``RandomHorizontalFlip``) are deliberately
    NOT applied here: these classic torchvision transforms operate on the
    image only, while the bounding boxes are produced separately by the
    dataset, so flipping the image would leave the boxes unmirrored and
    corrupt the training signal.  Use paired image+target transforms
    (torchvision ``transforms.v2``) if augmentation is needed.

    Args:
        train: Kept for API compatibility; both splits currently share the
            same deterministic pipeline.

    Returns:
        A ``T.Compose`` that converts a PIL image to a float tensor.
    """
    return T.Compose([T.ToTensor()])


def collate_fn(batch):
    """Transpose a list of (image, target) pairs into parallel tuples.

    Detection models accept variable-sized images, so the batch cannot be
    stacked into one tensor; instead return
    ``((img0, img1, ...), (tgt0, tgt1, ...))``.
    """
    images, targets = zip(*batch)
    return (images, targets)


def get_model(model_type: str, num_classes: int = 2):
    """Build a pretrained detector with its classification head resized.

    Args:
        model_type: 'ssd' (lightest), 'fasterrcnn_mobilenet' (medium) or
            'fasterrcnn' (heaviest, most accurate).
        num_classes: Number of classes including background (default 2:
            background + band).

    Returns:
        A torchvision detection model ready for fine-tuning.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """

    if model_type == 'ssd':
        # LIGHTEST - SSDLite with MobileNetV3 backbone
        from functools import partial

        from torch import nn
        from torchvision.models.detection import ssdlite320_mobilenet_v3_large
        from torchvision.models.detection import _utils as det_utils
        from torchvision.models.detection.ssdlite import SSDLiteClassificationHead

        model = ssdlite320_mobilenet_v3_large(weights='DEFAULT')
        # Derive head dimensions from the loaded model instead of hard-coding
        # them, and use the SSDLite head (depthwise-separable convs + BN) so
        # the replacement matches the rest of the architecture — the plain
        # SSDClassificationHead uses full convolutions without normalization.
        in_channels = det_utils.retrieve_out_channels(model.backbone, (320, 320))
        num_anchors = model.anchor_generator.num_anchors_per_location()
        norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
        model.head.classification_head = SSDLiteClassificationHead(
            in_channels, num_anchors, num_classes, norm_layer
        )
        return model

    elif model_type == 'fasterrcnn_mobilenet':
        # MEDIUM - Faster R-CNN with MobileNetV3 backbone
        from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn
        from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
        model = fasterrcnn_mobilenet_v3_large_fpn(weights='DEFAULT')
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        return model

    elif model_type == 'fasterrcnn':
        # HEAVIEST - Faster R-CNN with ResNet50 backbone
        from torchvision.models.detection import fasterrcnn_resnet50_fpn
        from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
        model = fasterrcnn_resnet50_fpn(weights='DEFAULT')
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        return model

    else:
        raise ValueError(f"Unknown model: {model_type}")


def _evaluate_loss(model, loader, device) -> float:
    """Return the average detection loss over ``loader`` without updating weights.

    Torchvision detection models only return their loss dict in ``train()``
    mode, so the model stays in train mode and is wrapped in ``torch.no_grad()``
    instead of using ``eval()``.  NOTE(review): with backbones that use regular
    BatchNorm this also updates BN running stats on validation data — fine for
    checkpoint selection, but confirm if exact eval statistics matter.
    """
    was_training = model.training
    model.train()
    total = 0.0
    with torch.no_grad():
        for images, targets in loader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            total += sum(loss for loss in loss_dict.values()).item()
    model.train(was_training)
    return total / max(len(loader), 1)


def train(
    data_dir: str,
    model_type: str = 'fasterrcnn_mobilenet',
    epochs: int = 50,
    batch_size: int = 2,
    lr: float = 0.005,
    device: str = 'auto',
    output_dir: str = './output',
    image_size: int = 640,
    use_amp: bool = True
):
    """Train a band detector with memory optimizations.

    Args:
        data_dir: Dataset root containing ``train/`` and ``val/`` splits.
        model_type: One of 'ssd', 'fasterrcnn_mobilenet', 'fasterrcnn'.
        epochs: Number of training epochs.
        batch_size: Images per batch (reduce if OOM).
        lr: SGD learning rate.
        device: 'auto', 'cuda' or 'cpu'.
        output_dir: Where checkpoints and model_info.json are written.
        image_size: Longest image side after resizing (reduce if OOM).
        use_amp: Enable CUDA mixed precision (ignored on CPU).

    Returns:
        The trained model (final-epoch weights).
    """

    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # AMP only makes sense on CUDA.
    use_amp = use_amp and device == 'cuda'

    print(f"{'='*50}")
    print(f"Training Configuration:")
    print(f"  Device: {device}")
    print(f"  Model: {model_type}")
    print(f"  Batch size: {batch_size}")
    print(f"  Image size: {image_size}")
    print(f"  Mixed precision: {use_amp}")
    print(f"{'='*50}\n")

    # Datasets
    train_dataset = WesternBlotDataset(
        data_dir, 'train', get_transform(train=True), max_size=image_size
    )
    val_dataset = WesternBlotDataset(
        data_dir, 'val', get_transform(train=False), max_size=image_size
    )

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=0, collate_fn=collate_fn, pin_memory=False
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=0, collate_fn=collate_fn, pin_memory=False
    )

    print(f"Train images: {len(train_dataset)}")
    print(f"Val images: {len(val_dataset)}\n")

    # Model
    print(f"Loading {model_type}...")
    model = get_model(model_type, num_classes=2)
    model.to(device)

    # Optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=0.0005)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    # Mixed precision
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    # Output
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    best_loss = float('inf')

    print("Starting training...\n")

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0

        for batch_idx, (images, targets) in enumerate(train_loader):
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            optimizer.zero_grad()

            if use_amp:
                with torch.cuda.amp.autocast():
                    loss_dict = model(images, targets)
                    losses = sum(loss for loss in loss_dict.values())
                # Scaled backward + step keeps fp16 gradients from underflowing.
                scaler.scale(losses).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())
                losses.backward()
                optimizer.step()

            epoch_loss += losses.item()

            # Clear cache periodically to reduce fragmentation-driven OOMs.
            if device == 'cuda' and batch_idx % 20 == 0:
                torch.cuda.empty_cache()

        lr_scheduler.step()

        avg_loss = epoch_loss / max(len(train_loader), 1)

        # Select the best checkpoint on *validation* loss, not training loss,
        # so "best_model.pt" reflects generalization rather than memorization.
        val_loss = _evaluate_loss(model, val_loader, device)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Val Loss: {val_loss:.4f}")

        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), output_path / 'best_model.pt')
            print(f"  -> Saved best model")

        if device == 'cuda':
            torch.cuda.empty_cache()

    torch.save(model.state_dict(), output_path / 'final_model.pt')

    # Save model info for inference
    with open(output_path / 'model_info.json', 'w') as f:
        json.dump({'model_type': model_type, 'image_size': image_size}, f)

    print(f"\nTraining complete! Models saved to: {output_path}")
    return model


if __name__ == "__main__":
    # Command-line entry point: parse arguments and launch training.
    cli = argparse.ArgumentParser(description='Train Western Blot detector')
    cli.add_argument('--data', type=str, required=True, help='Dataset directory')
    cli.add_argument('--model', type=str, default='fasterrcnn_mobilenet',
                     choices=['ssd', 'fasterrcnn_mobilenet', 'fasterrcnn'],
                     help='Model (ssd=lightest, fasterrcnn=heaviest)')
    cli.add_argument('--epochs', type=int, default=50)
    cli.add_argument('--batch-size', type=int, default=2, help='Reduce if OOM')
    cli.add_argument('--lr', type=float, default=0.005)
    cli.add_argument('--device', type=str, default='auto')
    cli.add_argument('--output', type=str, default='./output')
    cli.add_argument('--image-size', type=int, default=640, help='Reduce if OOM')
    cli.add_argument('--no-amp', action='store_true', help='Disable mixed precision')

    opts = cli.parse_args()

    train(
        data_dir=opts.data,
        model_type=opts.model,
        epochs=opts.epochs,
        batch_size=opts.batch_size,
        lr=opts.lr,
        device=opts.device,
        output_dir=opts.output,
        image_size=opts.image_size,
        use_amp=not opts.no_amp,
    )
