"""
Recursively copy all images from a folder to an output folder.

Usage:
    python collect_images.py --input ./data --output ./all_images
    
    # Flatten structure (no subfolders in output)
    python collect_images.py --input ./data --output ./all_images --flatten
    
    # Only specific formats
    python collect_images.py --input ./data --output ./all_images --formats png jpg
"""

import argparse
import shutil
from pathlib import Path
from typing import List, Set, Optional


_DEFAULT_FORMATS = frozenset(
    {'.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp', '.gif', '.webp'}
)


def _normalize_formats(formats: Optional[Set[str]]) -> Set[str]:
    """Return lowercase, dot-prefixed extensions (the defaults when *formats* is None)."""
    if formats is None:
        return set(_DEFAULT_FORMATS)
    if isinstance(formats, str):
        # A bare string such as "png" would otherwise be iterated char by char.
        formats = [formats]
    return {f'.{f.lower().lstrip(".")}' for f in formats}


def _find_images(input_dir: Path, output_dir: Path, formats: Set[str]) -> List[Path]:
    """Return the sorted image files under *input_dir*, skipping *output_dir*'s contents."""
    # If the output folder lives inside the input tree, images written by a
    # previous run must not be collected again (they would be re-copied into
    # output/output/... or duplicated with _N suffixes).
    try:
        excluded = output_dir.resolve().relative_to(input_dir.resolve()).parts
    except ValueError:
        excluded = ()  # output is outside the input tree: nothing to skip
    # (When output == input, `excluded` is () as well, so nothing is skipped.)

    images = []
    # One walk with a case-insensitive suffix test: catches mixed-case
    # extensions like ".Jpg" and avoids re-walking the tree per format.
    for path in input_dir.rglob('*'):
        # is_file() keeps directories named e.g. "photos.png" out of the copy.
        if path.suffix.lower() not in formats or not path.is_file():
            continue
        if excluded and path.relative_to(input_dir).parts[:len(excluded)] == excluded:
            continue
        images.append(path)
    return sorted(images)


def _unique_flat_name(name: str, used: Set[str], next_index: dict) -> str:
    """Return a file name derived from *name* that no earlier image in this run has taken."""
    # Names are compared case-insensitively so "A.png" and "a.png" cannot
    # overwrite each other on case-insensitive filesystems (macOS, Windows).
    key = name.lower()
    if key not in used:
        used.add(key)
        return name

    stem, suffix = Path(name).stem, Path(name).suffix
    # Resume counting where the last duplicate of this name stopped, and keep
    # going past suffixed names that a *real* file already claimed
    # (e.g. a genuine "x_1.png" found elsewhere in the tree).
    index = next_index.get(key, 0)
    while True:
        index += 1
        candidate = f"{stem}_{index}{suffix}"
        if candidate.lower() not in used:
            break
    next_index[key] = index
    used.add(candidate.lower())
    return candidate


def collect_images(
    input_dir: str,
    output_dir: str,
    formats: Optional[Set[str]] = None,
    flatten: bool = True,
    copy: bool = True,
    verbose: bool = True
) -> List[Path]:
    """
    Recursively collect all images from input_dir to output_dir.
    
    Args:
        input_dir: Source directory to search recursively
        output_dir: Destination directory for images (created if missing).
                 If it lies inside input_dir, its contents are not collected.
        formats: Set of extensions to include (e.g., {'.png', '.jpg'}).
                 Case-insensitive; the leading dot is optional.
                 Default: common image formats
        flatten: If True, copy all images to output root (no subfolders).
                 Name clashes get a "_N" suffix (compared case-insensitively).
                 If False, preserve directory structure
        copy: If True, copy files. If False, move files.
        verbose: Print progress
        
    Returns:
        List of paths to copied images in output_dir, in sorted source order

    Raises:
        ValueError: If input_dir does not exist or is not a directory.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)

    if not input_dir.exists():
        raise ValueError(f"Input directory does not exist: {input_dir}")
    if not input_dir.is_dir():
        raise ValueError(f"Input path is not a directory: {input_dir}")

    formats = _normalize_formats(formats)
    output_dir.mkdir(parents=True, exist_ok=True)

    images = _find_images(input_dir, output_dir, formats)
    if verbose:
        print(f"Found {len(images)} images in {input_dir}")

    transfer = shutil.copy2 if copy else shutil.move
    action = 'Copied' if copy else 'Moved'
    used_names: Set[str] = set()  # lowercased names already assigned (flatten mode)
    next_index: dict = {}         # lowercased name -> last "_N" suffix handed out
    copied = []

    for img_path in images:
        if flatten:
            dest = output_dir / _unique_flat_name(img_path.name, used_names, next_index)
        else:
            # Mirror the path relative to the input root.
            dest = output_dir / img_path.relative_to(input_dir)
            dest.parent.mkdir(parents=True, exist_ok=True)

        transfer(str(img_path), str(dest))
        copied.append(dest)

        if verbose:
            print(f"  {action}: {img_path.name} -> {dest}")

    if verbose:
        print(f"\nDone! {len(copied)} images in {output_dir}")

    return copied


def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point: parse arguments and run collect_images().

    Args:
        argv: Argument list to parse. None (the default) makes argparse read
            sys.argv[1:], exactly as before.
    """
    parser = argparse.ArgumentParser(
        description='Recursively copy all images to output folder',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python collect_images.py --input ./data --output ./all_images
  python collect_images.py --input ./data --output ./all_images --flatten
  python collect_images.py --input ./data --output ./all_images --keep-structure
  python collect_images.py --input ./data --output ./all_images --formats png jpg tif
  python collect_images.py --input ./data --output ./all_images --move
        """
    )
    
    parser.add_argument('--input', '-i', type=str, required=True,
                        help='Input directory to search recursively')
    parser.add_argument('--output', '-o', type=str, required=True,
                        help='Output directory for images')
    parser.add_argument('--formats', '-f', nargs='+', default=None,
                        help='Image formats to include (default: png jpg jpeg tif tiff bmp gif webp)')
    parser.add_argument('--flatten', action='store_true',
                        help='Put all images in output root (no subfolders); this is the default')
    parser.add_argument('--keep-structure', action='store_true',
                        help='Preserve directory structure (opposite of --flatten; '
                             '--flatten wins if both are given)')
    parser.add_argument('--move', action='store_true',
                        help='Move files instead of copying')
    parser.add_argument('--quiet', '-q', action='store_true',
                        help='Suppress output')
    
    args = parser.parse_args(argv)
    
    # Flattening is the default. --keep-structure turns it off, but an explicit
    # --flatten still takes precedence when both flags are passed.
    flatten = args.flatten or not args.keep_structure
    
    formats = set(args.formats) if args.formats else None
    
    try:
        collect_images(
            input_dir=args.input,
            output_dir=args.output,
            formats=formats,
            flatten=flatten,
            copy=not args.move,
            verbose=not args.quiet
        )
    except ValueError as err:
        # A bad --input path is a usage error. Report it like one (message
        # plus exit status 2) instead of dumping a traceback.
        parser.error(str(err))


if __name__ == "__main__":
    main()
