"""
Auto-Label Western Blot Images using Grounding DINO

Uses zero-shot object detection to automatically find and label bands.
The generated labels can be reviewed/corrected in tools like CVAT, Roboflow, or Label Studio.

Features:
  - AI-powered band detection (Grounding DINO)
  - Auto-rotation to level tilted images
  - Band mode: One annotation per individual band
  - Row mode: One annotation per horizontal row

Installation:
    pip install torch torchvision
    pip install transformers
    pip install opencv-python
    pip install pillow

Usage:
    # Auto-label individual bands (default)
    python auto_label.py --image western.png

    # Auto-label by ROWS (groups bands horizontally)
    python auto_label.py --image western.png --mode row

    # Auto-rotate tilted images before labeling
    python auto_label.py --image western.png --auto-rotate

    # Combine auto-rotate with row mode
    python auto_label.py --image western.png --auto-rotate --mode row

    # Auto-label a folder of images  
    python auto_label.py --input-dir ./images --output-dir ./labels

    # Row mode with custom tolerance (how close Y positions must be to group)
    python auto_label.py --image western.png --mode row --row-tolerance 30

    # More detections (lower threshold)
    python auto_label.py --image western.png --threshold 0.15

    # Custom prompt (separate phrases with " . ")
    python auto_label.py --image western.png --prompt "dark band . protein band"
"""

import argparse
import json
from pathlib import Path
from typing import List, Tuple
import cv2
import numpy as np
from PIL import Image


def auto_rotate_image(
    image: np.ndarray,
    max_angle: float = 15.0,
    method: str = 'hough'
) -> Tuple[np.ndarray, float]:
    """
    Estimate how far an image is tilted and rotate it back to level.

    Bands on a Western blot are expected to run horizontally, so a tilt
    angle is estimated from the grayscale image and the image is rotated by
    that amount. The output canvas is enlarged so nothing is cropped, and
    the newly exposed corners are filled with white.

    Args:
        image: Input image as a numpy array (3-channel RGB or single-channel)
        max_angle: Largest correction in degrees; an estimate beyond this is
            treated as unreliable and no rotation is applied (default: 15)
        method: 'hough' selects the Hough-line estimator; any other value
            falls back to the image-moments estimator

    Returns:
        (rotated_image, angle): the corrected image and the rotation applied
        in degrees, or (image, 0.0) when no rotation was applied.
    """
    is_color = len(image.shape) == 3
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if is_color else image.copy()

    estimator = _detect_angle_hough if method == 'hough' else _detect_angle_moments
    angle = estimator(gray)

    # An estimate outside the allowed range is more likely noise than real tilt.
    if abs(angle) > max_angle:
        angle = 0
    # Tilts under half a degree are left alone.
    if abs(angle) < 0.5:
        return image, 0.0

    height, width = image.shape[:2]
    matrix = cv2.getRotationMatrix2D((width // 2, height // 2), angle, 1.0)

    # Size of the axis-aligned box that holds the entire rotated image.
    abs_cos, abs_sin = np.abs(matrix[0, 0]), np.abs(matrix[0, 1])
    out_w = int(height * abs_sin + width * abs_cos)
    out_h = int(height * abs_cos + width * abs_sin)

    # Translate so the rotated content sits in the middle of the larger canvas.
    matrix[0, 2] += (out_w - width) / 2
    matrix[1, 2] += (out_h - height) / 2

    fill = (255, 255, 255) if is_color else 255
    rotated = cv2.warpAffine(
        image, matrix, (out_w, out_h),
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=fill
    )
    return rotated, angle


def _near_horizontal_angles(lines: np.ndarray, max_abs_angle: float = 45.0) -> np.ndarray:
    """
    Return the angles (degrees) of segments within ``max_abs_angle`` of horizontal.

    ``lines`` holds segments as rows of ``x1, y1, x2, y2`` (the
    ``(N, 1, 4)`` layout returned by ``cv2.HoughLinesP``; any shape that
    reshapes to ``(-1, 4)`` works). Each segment is oriented left-to-right
    before measuring, so its angle does not depend on which endpoint is
    listed first. Vertical segments (``x1 == x2``) are skipped.
    """
    seg = np.asarray(lines, dtype=np.float64).reshape(-1, 4)
    dx = seg[:, 2] - seg[:, 0]
    dy = seg[:, 3] - seg[:, 1]

    keep = dx != 0
    dx, dy = dx[keep], dy[keep]

    # Flip right-to-left segments: otherwise a near-horizontal line would
    # measure ~ +/-180 degrees and be discarded by the filter below.
    flip = dx < 0
    dx = np.where(flip, -dx, dx)
    dy = np.where(flip, -dy, dy)

    angles = np.degrees(np.arctan2(dy, dx))
    return angles[np.abs(angles) < max_abs_angle]


def _detect_angle_hough(gray: np.ndarray) -> float:
    """
    Estimate image tilt (degrees) from near-horizontal line segments.

    Canny edges are dilated horizontally (5x1 kernel) to join fragments,
    then ``cv2.HoughLinesP`` finds segments at least 50 px long. The median
    angle of the segments within 45 degrees of horizontal is returned; the
    median keeps a few stray segments from skewing the estimate.

    Args:
        gray: Single-channel image

    Returns:
        Median angle in degrees (positive when segments slope downward to the
        right in image coordinates), or 0.0 when no suitable segment exists.
    """
    edges = cv2.Canny(gray, 50, 150, apertureSize=3)

    # Dilate horizontally to connect nearby edge fragments along a band.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 1))
    edges = cv2.dilate(edges, kernel, iterations=1)

    lines = cv2.HoughLinesP(
        edges,
        rho=1,
        theta=np.pi / 180,
        threshold=100,
        minLineLength=50,
        maxLineGap=10
    )

    if lines is None or len(lines) == 0:
        return 0.0

    angles = _near_horizontal_angles(lines)
    if angles.size == 0:
        return 0.0

    return float(np.median(angles))


def _detect_angle_moments(gray: np.ndarray) -> float:
    """
    Estimate image tilt (degrees) from the orientation of the largest blobs.

    The image is binarised with Otsu's method (inverted, so dark regions
    become foreground) and its outer contours are extracted. An ellipse is
    fitted to each of the ten largest contours, skipping any whose area is
    below 100 px or that have fewer than 5 points (the minimum
    ``cv2.fitEllipse`` accepts). Ellipse angles above 45 degrees are shifted
    down by 90 so they are measured relative to horizontal.

    NOTE(review): the 45/90 adjustment assumes OpenCV's ``fitEllipse``
    reports roughly 90 degrees for a horizontal blob; confirm for the
    OpenCV version in use.

    Args:
        gray: Single-channel image

    Returns:
        Median of the collected angles, or 0.0 when no contour qualifies.
    """
    _, mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    blobs, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not blobs:
        return 0.0

    largest = sorted(blobs, key=cv2.contourArea, reverse=True)[:10]

    tilts = []
    for blob in largest:
        if cv2.contourArea(blob) < 100 or len(blob) < 5:
            continue
        tilt = cv2.fitEllipse(blob)[2]
        tilts.append(tilt - 90 if tilt > 45 else tilt)

    return np.median(tilts) if tilts else 0.0


def group_detections_into_rows(
    detections: List[dict],
    y_tolerance: int = 30,
    image_height: int = None
) -> List[dict]:
    """
    Group individual band detections into horizontal rows.

    Detections are sorted by the vertical centre of their box. Walking that
    order, a detection joins the current row when its centre lies within
    ``y_tolerance`` pixels of the row's mean centre; otherwise it starts a
    new row. Each row is then merged into a single bounding box.

    Args:
        detections: Band detections, each a dict with 'box'
            ([x1, y1, x2, y2]) and 'score'
        y_tolerance: Maximum distance (pixels) between a detection's centre
            and the current row's mean centre for it to join that row
        image_height: Accepted for API compatibility; currently unused

    Returns:
        One dict per row, ordered top to bottom, with 'box' (union of the
        member boxes), 'score' (mean member score), 'label' ('row_<index>'),
        'num_bands' and 'bands' (the member detections in ascending centre
        order). An empty input gives an empty list.
    """
    if not detections:
        return []

    def y_center(det: dict) -> float:
        return (det['box'][1] + det['box'][3]) / 2

    sorted_dets = sorted(detections, key=y_center)

    rows = []
    current_row = []
    center_sum = 0.0  # running sum of current_row centres (mean = sum / len)
    for det in sorted_dets:
        det_y = y_center(det)
        # A running sum avoids re-averaging the whole row on every append.
        if current_row and abs(det_y - center_sum / len(current_row)) <= y_tolerance:
            current_row.append(det)
            center_sum += det_y
        else:
            if current_row:
                rows.append(current_row)
            current_row = [det]
            center_sum = det_y
    rows.append(current_row)

    row_detections = []
    for row_idx, row in enumerate(rows):
        # Union box that contains every band in the row.
        x1 = min(d['box'][0] for d in row)
        y1 = min(d['box'][1] for d in row)
        x2 = max(d['box'][2] for d in row)
        y2 = max(d['box'][3] for d in row)

        avg_score = np.mean([d['score'] for d in row])

        row_detections.append({
            'box': [x1, y1, x2, y2],
            'score': float(avg_score),
            'label': f'row_{row_idx}',
            'num_bands': len(row),
            'bands': row  # Keep original bands for reference
        })

    return row_detections


class WesternBlotAutoLabeler:
    """
    Auto-label Western blot images using Grounding DINO.

    Runs a HuggingFace zero-shot object detector over images and writes the
    resulting boxes as YOLO .txt, COCO .json and an annotated PNG.
    """

    def __init__(
        self,
        model_id: str = "IDEA-Research/grounding-dino-tiny",
        device: str = "auto",
        box_threshold: float = 0.25,
        text_threshold: float = 0.20
    ):
        """
        Load the Grounding DINO processor and model.

        Args:
            model_id: HuggingFace model ID
            device: 'auto' (CUDA when available, otherwise CPU), 'cuda', or 'cpu'
            box_threshold: Confidence threshold for detections
            text_threshold: Text matching threshold
        """
        # Imported here rather than at module level, so torch/transformers
        # are only required once a labeler is actually constructed.
        import torch
        from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

        if device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.box_threshold = box_threshold
        self.text_threshold = text_threshold

        print("Loading Grounding DINO...")
        print(f"  Model: {model_id}")
        print(f"  Device: {self.device}")

        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
        self.model.to(self.device)
        self.model.eval()

        print("Ready!\n")

    def detect(
        self,
        image: "Image.Image",
        text_prompt: str = "dark horizontal band . protein band . blot band"
    ) -> List[dict]:
        """
        Detect bands in an image.

        Args:
            image: PIL Image
            text_prompt: Text description for detection (separate with ' . ')

        Returns:
            List of dicts with 'box' ([x1, y1, x2, y2] in pixels of ``image``),
            'score' (float) and 'label' (as returned by the processor).
        """
        import torch

        inputs = self.processor(images=image, text=text_prompt, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)

        # target_sizes expects (height, width); PIL's size is (width, height).
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            inputs["input_ids"],
            box_threshold=self.box_threshold,
            text_threshold=self.text_threshold,
            target_sizes=[image.size[::-1]]
        )[0]

        return [
            {
                'box': box.cpu().tolist(),  # [x1, y1, x2, y2]
                'score': float(score.cpu()),
                'label': label
            }
            for box, score, label in zip(results["boxes"], results["scores"], results["labels"])
        ]

    def auto_label_image(
        self,
        image_path: str,
        text_prompt: str = "dark horizontal band . protein band . blot band",
        output_dir: str = None,
        save_visualization: bool = True,
        save_yolo: bool = True,
        save_coco: bool = True,
        mode: str = 'band',  # 'band' or 'row'
        row_tolerance: int = 30,
        auto_rotate: bool = False,
        max_rotation: float = 15.0,
        save_rotated: bool = True
    ) -> dict:
        """
        Auto-label a single image and save annotations.

        When auto-rotation is applied, detection runs on the rotated image, so
        all saved boxes are in the rotated image's pixel coordinates. The COCO
        'file_name' then names the saved rotated copy. If ``save_rotated`` is
        False, a warning is printed because no file matching those
        coordinates is written.

        Args:
            image_path: Path to input image
            text_prompt: Detection prompt
            output_dir: Output directory (default: same as image)
            save_visualization: Save image with drawn boxes
            save_yolo: Save YOLO format annotations
            save_coco: Save COCO format annotations
            mode: 'band' for individual bands, 'row' for grouped rows
            row_tolerance: Y-distance tolerance for row grouping (pixels)
            auto_rotate: Automatically detect and correct image rotation
            max_rotation: Maximum rotation angle to correct (degrees)
            save_rotated: Save the rotated image separately

        Returns:
            Dict with:
              - 'image_path': the input path
              - 'annotated_image': file name of the image the boxes refer to
              - 'width', 'height': size of that image
              - 'num_detections', 'detections', 'mode', 'rotation_angle'
        """
        image_path = Path(image_path)

        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)
        else:
            output_dir = image_path.parent

        image = Image.open(image_path).convert("RGB")
        image_np = np.array(image)
        w, h = image.size

        print(f"Processing: {image_path.name} ({w}x{h}) [mode: {mode}]")

        # Name of the image whose pixel grid the saved boxes are expressed in.
        annotated_name = image_path.name

        rotation_angle = 0.0
        if auto_rotate:
            image_np, rotation_angle = auto_rotate_image(image_np, max_angle=max_rotation)
            if abs(rotation_angle) > 0.5:
                print(f"  Auto-rotated: {rotation_angle:.1f}°")
                image = Image.fromarray(image_np)
                w, h = image.size

                if save_rotated:
                    rotated_path = output_dir / f"{image_path.stem}_rotated{image_path.suffix}"
                    image.save(rotated_path)
                    print(f"  Saved rotated: {rotated_path}")
                    # Boxes come from the rotated pixels, so they belong to this file.
                    annotated_name = rotated_path.name
                else:
                    print("  Warning: boxes refer to the rotated image, which is not being saved")

        detections = self.detect(image, text_prompt)
        print(f"  Found {len(detections)} individual bands")

        if mode == 'row' and detections:
            detections = group_detections_into_rows(detections, y_tolerance=row_tolerance, image_height=h)
            print(f"  Grouped into {len(detections)} rows")

        suffix = '_rows' if mode == 'row' else ''

        if save_yolo:
            yolo_lines = []
            for det in detections:
                x1, y1, x2, y2 = det['box']
                # YOLO: class, then centre and size normalised by image size.
                x_center = ((x1 + x2) / 2) / w
                y_center = ((y1 + y2) / 2) / h
                box_w = (x2 - x1) / w
                box_h = (y2 - y1) / h
                yolo_lines.append(f"0 {x_center:.6f} {y_center:.6f} {box_w:.6f} {box_h:.6f}")

            yolo_path = output_dir / f"{image_path.stem}{suffix}.txt"
            with open(yolo_path, 'w') as f:
                f.write('\n'.join(yolo_lines))
            print(f"  Saved: {yolo_path}")

        if save_coco:
            category_name = 'row' if mode == 'row' else 'band'
            coco_data = {
                'image': {
                    'file_name': annotated_name,
                    'width': w,
                    'height': h
                },
                'annotations': [
                    {
                        'id': i,
                        'bbox': [det['box'][0], det['box'][1],
                                 det['box'][2] - det['box'][0],  # width
                                 det['box'][3] - det['box'][1]],  # height
                        'area': (det['box'][2] - det['box'][0]) * (det['box'][3] - det['box'][1]),
                        'category_id': 0,
                        'category_name': category_name,
                        'score': det['score'],
                        'iscrowd': 0,
                        'num_bands': det.get('num_bands', 1)  # For row mode
                    }
                    for i, det in enumerate(detections)
                ],
                'categories': [{'id': 0, 'name': category_name}],
                'mode': mode
            }

            json_path = output_dir / f"{image_path.stem}{suffix}.json"
            with open(json_path, 'w') as f:
                json.dump(coco_data, f, indent=2)
            print(f"  Saved: {json_path}")

        if save_visualization:
            vis_image = self._draw_detections(np.array(image), detections, mode=mode)
            vis_path = output_dir / f"{image_path.stem}{suffix}_auto_labeled.png"
            # The visualization is RGB; OpenCV writes BGR.
            cv2.imwrite(str(vis_path), cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR))
            print(f"  Saved: {vis_path}")

        return {
            'image_path': str(image_path),
            'annotated_image': annotated_name,
            'width': w,
            'height': h,
            'num_detections': len(detections),
            'detections': detections,
            'mode': mode,
            'rotation_angle': float(rotation_angle)
        }

    def auto_label_directory(
        self,
        input_dir: str,
        output_dir: str = None,
        text_prompt: str = "dark horizontal band . protein band . blot band",
        save_visualization: bool = True,
        mode: str = 'band',
        row_tolerance: int = 30,
        auto_rotate: bool = False,
        max_rotation: float = 15.0,
        save_rotated: bool = True
    ) -> dict:
        """
        Auto-label every image file directly inside a directory.

        Files are processed in sorted path order, so image ids in the
        combined COCO file are deterministic. A directory with no images is
        handled gracefully: an empty combined COCO file is written.

        Args:
            input_dir: Input directory with images
            output_dir: Output directory for annotations (default: input_dir/labels)
            text_prompt: Detection prompt
            save_visualization: Save visualization images
            mode: 'band' for individual bands, 'row' for grouped rows
            row_tolerance: Y-distance tolerance for row grouping
            auto_rotate: Automatically detect and correct image rotation
            max_rotation: Maximum rotation angle to correct (degrees)
            save_rotated: Save rotated images separately

        Returns:
            Summary dict with 'num_images', 'total_detections', 'output_dir', 'mode'
        """
        input_dir = Path(input_dir)
        output_dir = Path(output_dir) if output_dir else input_dir / 'labels'
        output_dir.mkdir(parents=True, exist_ok=True)

        image_extensions = {'.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp', '.gif'}
        # iterdir() order is filesystem-dependent; sort for reproducible ids.
        image_paths = sorted(
            p for p in input_dir.iterdir()
            if p.is_file() and p.suffix.lower() in image_extensions
        )
        num_images = len(image_paths)

        print(f"Found {num_images} images in {input_dir}")
        print(f"Mode: {mode}" + (f" (tolerance: {row_tolerance}px)" if mode == 'row' else ''))
        if auto_rotate:
            print(f"Auto-rotate: enabled (max: {max_rotation}°)")
        print()

        total_detections = 0
        results = []

        for i, img_path in enumerate(image_paths, 1):
            print(f"[{i}/{num_images}] ", end='')
            result = self.auto_label_image(
                str(img_path),
                text_prompt=text_prompt,
                output_dir=str(output_dir),
                save_visualization=save_visualization,
                mode=mode,
                row_tolerance=row_tolerance,
                auto_rotate=auto_rotate,
                max_rotation=max_rotation,
                save_rotated=save_rotated
            )
            total_detections += result['num_detections']
            results.append(result)
            print()

        self._save_combined_coco(results, output_dir, mode)

        # Guard against an empty directory (previously a ZeroDivisionError).
        average = total_detections / num_images if num_images else 0.0

        label_type = "rows" if mode == 'row' else "bands"
        print(f"\n{'='*50}")
        print("Auto-labeling complete!")
        print(f"  Images processed: {num_images}")
        print(f"  Total {label_type}: {total_detections}")
        print(f"  Average per image: {average:.1f}")
        print(f"  Output: {output_dir}")
        print("\nNext steps:")
        print("  1. Review labels in CVAT, Roboflow, or Label Studio")
        print("  2. Correct any errors")
        print("  3. Export and train your model")
        print(f"{'='*50}")

        return {
            'num_images': num_images,
            'total_detections': total_detections,
            'output_dir': str(output_dir),
            'mode': mode
        }

    def _draw_detections(self, image: np.ndarray, detections: List[dict], mode: str = 'band') -> np.ndarray:
        """
        Return a copy of ``image`` with each detection's box and label tag drawn.

        Colors are RGB tuples, matching the RGB array that auto_label_image
        passes in (it converts to BGR only when writing the file). Row mode
        cycles through distinct colors per row. Band mode colors by
        confidence: green > 0.5, yellow > 0.3, orange otherwise.
        """
        vis = image.copy()
        row_colors = [
            (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
            (255, 0, 255), (0, 255, 255), (128, 0, 255), (255, 128, 0)
        ]
        thickness = 3 if mode == 'row' else 2

        for i, det in enumerate(detections):
            x1, y1, x2, y2 = [int(v) for v in det['box']]
            score = det['score']

            if mode == 'row':
                color = row_colors[i % len(row_colors)]
                label = f"Row {i+1} ({det.get('num_bands', 1)} bands)"
            else:
                if score > 0.5:
                    color = (0, 255, 0)  # Green
                elif score > 0.3:
                    color = (255, 255, 0)  # Yellow (RGB)
                else:
                    color = (255, 165, 0)  # Orange (RGB)
                label = f"{score:.2f}"

            cv2.rectangle(vis, (x1, y1), (x2, y2), color, thickness)

            (label_w, label_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            # Tag goes above the box; if that would leave the image, place it
            # just inside the box's top edge so the text stays visible.
            tag_top = y1 - label_h - 6
            if tag_top < 0:
                tag_top = y1
            tag_bottom = tag_top + label_h + 6
            cv2.rectangle(vis, (x1, tag_top), (x1 + label_w + 4, tag_bottom), color, -1)
            cv2.putText(vis, label, (x1 + 2, tag_bottom - 4),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

        return vis

    def _save_combined_coco(self, results: List[dict], output_dir: Path, mode: str = 'band'):
        """
        Write one COCO file covering every result in ``results``.

        Each image record uses the result's 'annotated_image' name when
        present, and falls back to the name of 'image_path'. It includes
        'width'/'height' when the result provides them. The output is
        ``annotations_coco.json``, or ``annotations_coco_rows.json`` in row
        mode.
        """
        images = []
        annotations = []
        ann_id = 0

        for img_id, result in enumerate(results):
            file_name = result.get('annotated_image') or Path(result['image_path']).name
            image_entry = {'id': img_id, 'file_name': file_name}
            if 'width' in result and 'height' in result:
                image_entry['width'] = result['width']
                image_entry['height'] = result['height']
            images.append(image_entry)

            for det in result['detections']:
                x1, y1, x2, y2 = det['box']
                annotations.append({
                    'id': ann_id,
                    'image_id': img_id,
                    'category_id': 0,
                    'bbox': [x1, y1, x2 - x1, y2 - y1],
                    'area': (x2 - x1) * (y2 - y1),
                    'iscrowd': 0,
                    'score': det['score'],
                    'num_bands': det.get('num_bands', 1)
                })
                ann_id += 1

        category_name = 'row' if mode == 'row' else 'band'
        coco_dataset = {
            'images': images,
            'annotations': annotations,
            'categories': [{'id': 0, 'name': category_name, 'supercategory': 'western_blot'}],
            'mode': mode
        }

        suffix = '_rows' if mode == 'row' else ''
        coco_path = output_dir / f'annotations_coco{suffix}.json'
        with open(coco_path, 'w') as f:
            json.dump(coco_dataset, f, indent=2)
        print(f"Saved combined COCO: {coco_path}")


def main():
    """Parse command-line arguments and label a single image or a directory."""
    parser = argparse.ArgumentParser(
        description='Auto-label Western Blot images using AI (Grounding DINO)',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Single image - annotate individual bands
  python auto_label.py --image western.png
  
  # Single image - annotate by ROWS (groups bands horizontally)
  python auto_label.py --image western.png --mode row
  
  # Auto-rotate tilted images before labeling
  python auto_label.py --image western.png --auto-rotate
  
  # Folder of images with auto-rotation
  python auto_label.py --input-dir ./blots --output-dir ./labels --auto-rotate
  
  # Row mode with custom tolerance (pixels)
  python auto_label.py --image western.png --mode row --row-tolerance 50
  
  # Detect more bands (lower threshold = more detections)
  python auto_label.py --image western.png --threshold 0.15
  
  # Custom detection prompt
  python auto_label.py --image western.png --prompt "dark band . protein . blot"

Auto-Rotation:
  --auto-rotate    Detect and correct image tilt (bands should be horizontal)
  --max-rotation   Maximum degrees to correct (default: 15)
  --no-save-rotated  Don't save the rotated image separately

Annotation Modes:
  - band (default): One box per individual band
  - row: One box per horizontal row (groups bands at similar Y positions)

Output formats (row mode inserts "_rows" after image_name):
  - YOLO: image_name.txt (one line per box: class x_center y_center width height)
  - COCO: image_name.json (full annotation with boxes, scores)
  - Visualization: image_name_auto_labeled.png
  - Directory mode also writes annotations_coco.json (all images combined)

After auto-labeling:
  1. Import into CVAT, Roboflow, or Label Studio
  2. Review and correct the labels
  3. Export in your preferred format
  4. Train your model with the corrected labels
        """
    )

    # Input options: exactly one of --image / --input-dir.
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument('--image', type=str, help='Single image to label')
    input_group.add_argument('--input-dir', type=str, help='Directory of images')

    # Output options
    parser.add_argument('--output-dir', type=str, default=None,
                        help='Output directory (default: ./labels for --image, '
                             '<input-dir>/labels for --input-dir)')

    # Annotation mode
    parser.add_argument('--mode', type=str, default='band', choices=['band', 'row'],
                        help='Annotation mode: "band" for individual bands, "row" for grouped rows')
    parser.add_argument('--row-tolerance', type=int, default=30,
                        help='Y-distance tolerance for row grouping in pixels (default: 30)')

    # Auto-rotation options
    parser.add_argument('--auto-rotate', action='store_true',
                        help='Automatically detect and correct image rotation')
    parser.add_argument('--max-rotation', type=float, default=15.0,
                        help='Maximum rotation angle to correct in degrees (default: 15)')
    parser.add_argument('--no-save-rotated', action='store_true',
                        help='Do not save the rotated image separately')

    # Detection options
    parser.add_argument('--prompt', type=str,
                        default="dark horizontal band . protein band . blot band",
                        help='Detection prompt (separate terms with " . ")')
    parser.add_argument('--threshold', type=float, default=0.25,
                        help='Detection confidence threshold (default: 0.25, lower=more detections)')
    parser.add_argument('--text-threshold', type=float, default=0.20,
                        help='Text matching threshold (default: 0.20)')

    # Other options
    parser.add_argument('--device', type=str, default='auto',
                        choices=['auto', 'cuda', 'cpu'])
    parser.add_argument('--no-vis', action='store_true',
                        help='Skip saving visualization images')
    parser.add_argument('--model', type=str, default='IDEA-Research/grounding-dino-tiny',
                        help='HuggingFace model ID')

    args = parser.parse_args()

    labeler = WesternBlotAutoLabeler(
        model_id=args.model,
        device=args.device,
        box_threshold=args.threshold,
        text_threshold=args.text_threshold
    )

    # Options shared by the single-image and directory paths.
    common = dict(
        text_prompt=args.prompt,
        save_visualization=not args.no_vis,
        mode=args.mode,
        row_tolerance=args.row_tolerance,
        auto_rotate=args.auto_rotate,
        max_rotation=args.max_rotation,
        save_rotated=not args.no_save_rotated
    )

    if args.image:
        output_dir = args.output_dir or './labels'
        labeler.auto_label_image(args.image, output_dir=output_dir, **common)
    else:
        output_dir = args.output_dir or str(Path(args.input_dir) / 'labels')
        labeler.auto_label_directory(args.input_dir, output_dir=output_dir, **common)


if __name__ == "__main__":
    main()
