"""
Document Compression Service
Handles compression of images and PDFs for KYC documents and other use cases
"""
import io
import logging
import os
from typing import Optional, Tuple

from PIL import Image
from PIL import ImageOps

logger = logging.getLogger(__name__)


class DocumentCompressionService:
    """Service for compressing documents (images and PDFs).

    The service holds no state; every method works purely on the bytes it is
    given and returns new bytes.
    """

    # Maximum width for images
    MAX_IMAGE_WIDTH = 500

    # Output formats compress_image knows how to encode.
    _SUPPORTED_IMAGE_FORMATS = ('JPEG', 'PNG', 'WEBP')

    # Colour painted behind transparent pixels when alpha has to be removed.
    _FLATTEN_BACKGROUND = (255, 255, 255)

    # Image.info entries that hold embedded metadata. They are removed so that
    # no encoder can copy them into the output.
    _METADATA_INFO_KEYS = ('exif', 'xmp', 'comment')

    # File extension -> MIME type, used when only a filename is available.
    _EXTENSION_MIME_TYPES = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.webp': 'image/webp',
        '.gif': 'image/gif',
        '.pdf': 'application/pdf'
    }

    def __init__(self):
        """Create the service; there is nothing to configure."""

    def compress_image(
        self,
        file_content: bytes,
        max_width: int = MAX_IMAGE_WIDTH,
        quality: int = 85,
        format: Optional[str] = None,
        convert_to_webp: bool = False
    ) -> Tuple[bytes, str]:
        """
        Compress an image file

        The image is rotated according to its EXIF orientation, stripped of
        embedded metadata, scaled down to ``max_width`` (never up) and
        re-encoded. JPEG and WebP output is always opaque RGB: transparent
        areas are flattened onto a white background. PNG output keeps
        transparency.

        Args:
            file_content: Original image file content as bytes
            max_width: Maximum width in pixels (default: 500)
            quality: Image quality (1-100, default: 85); used for JPEG and
                WebP, ignored for PNG
            format: Output format ('JPEG'/'JPG', 'PNG', 'WEBP'). Takes
                precedence over convert_to_webp. Any other value falls back
                to WebP. If None, the format is chosen automatically.
            convert_to_webp: If True, always encode as WebP. If False
                (default), JPEG stays JPEG, PNG without transparency becomes
                JPEG and everything else becomes WebP.

        Returns:
            Tuple of (compressed_bytes, mime_type)

        Raises:
            ValueError: If image cannot be processed
        """
        try:
            if max_width < 1:
                raise ValueError(f"max_width must be a positive integer, got {max_width}")

            # Open image from bytes
            image = Image.open(io.BytesIO(file_content))

            # Read these from the freshly opened image: every derived image
            # (convert, transpose, copy) reports format=None.
            original_format = (image.format or 'JPEG').upper()
            source_has_alpha = image.mode in ('RGBA', 'LA', 'P')

            image = self._apply_exif_orientation(image)
            output_format = self._select_output_format(
                original_format, source_has_alpha, format, convert_to_webp
            )
            image = self._prepare_pixels(image, output_format)
            image = self._fit_to_width(image, max_width)
            compressed_bytes, mime_type = self._encode_image(image, output_format, quality)
        except Exception as e:
            logger.error(f"Error compressing image: {str(e)}")
            raise ValueError(f"Failed to compress image: {str(e)}") from e

        self._log_reduction("Image", len(file_content), len(compressed_bytes))
        return compressed_bytes, mime_type

    @staticmethod
    def _apply_exif_orientation(image: "Image.Image") -> "Image.Image":
        """Rotate/flip the image as its EXIF orientation tag requests."""
        # The metadata is dropped later, so the orientation has to be baked
        # into the pixels now or viewers would show the picture rotated.
        try:
            upright = ImageOps.exif_transpose(image)
        except Exception as exc:
            # Broken EXIF must not block compression; keep the decoded pixels.
            logger.warning(f"Could not apply EXIF orientation: {exc}")
            return image
        return image if upright is None else upright

    @classmethod
    def _select_output_format(
        cls,
        original_format: str,
        source_has_alpha: bool,
        requested_format: Optional[str],
        convert_to_webp: bool
    ) -> str:
        """Pick the encoder ('JPEG', 'PNG' or 'WEBP') for the output."""
        if requested_format is not None:
            chosen = requested_format.upper()
            if chosen == 'JPG':
                chosen = 'JPEG'
            if chosen not in cls._SUPPORTED_IMAGE_FORMATS:
                logger.warning(f"Unsupported output format: {requested_format}, using WEBP")
                chosen = 'WEBP'
            return chosen
        if convert_to_webp:
            return 'WEBP'
        # Legacy behavior: keep JPEG, turn opaque PNG into JPEG,
        # default to WebP for everything else.
        if original_format in ('JPEG', 'JPG'):
            return 'JPEG'
        if original_format == 'PNG' and not source_has_alpha:
            return 'JPEG'
        return 'WEBP'

    @classmethod
    def _prepare_pixels(cls, image: "Image.Image", output_format: str) -> "Image.Image":
        """Strip metadata and bring the image into a mode the encoder accepts."""
        if image.mode == 'P':
            # Expand the palette; RGBA keeps any palette transparency.
            image = image.convert('RGBA')
        elif image.mode not in ('RGBA', 'LA', 'RGB', 'L'):
            image = image.convert('RGB')

        for key in cls._METADATA_INFO_KEYS:
            image.info.pop(key, None)

        if output_format == 'PNG':
            return image
        if image.mode in ('RGBA', 'LA'):
            # JPEG has no alpha channel and WebP output is kept as plain RGB.
            return cls._flatten_alpha(image)
        if output_format == 'JPEG' and image.mode != 'RGB':
            return image.convert('RGB')
        return image

    @classmethod
    def _flatten_alpha(cls, image: "Image.Image") -> "Image.Image":
        """Composite an image with alpha onto an opaque background."""
        # A plain convert('RGB') would only drop the alpha band and expose
        # whatever colour is stored under transparent pixels (often black).
        rgba = image if image.mode == 'RGBA' else image.convert('RGBA')
        canvas = Image.new('RGB', rgba.size, cls._FLATTEN_BACKGROUND)
        canvas.paste(rgba, mask=rgba.getchannel('A'))
        return canvas

    @staticmethod
    def _fit_to_width(image: "Image.Image", max_width: int) -> "Image.Image":
        """Scale the image down to max_width, maintaining aspect ratio."""
        width, height = image.size
        if width <= max_width:
            return image
        ratio = max_width / width
        # Extremely wide images would otherwise round down to a height of 0.
        new_height = max(1, int(height * ratio))
        # Use high-quality resampling
        return image.resize((max_width, new_height), Image.Resampling.LANCZOS)

    @staticmethod
    def _encode_image(
        image: "Image.Image",
        output_format: str,
        quality: int
    ) -> Tuple[bytes, str]:
        """Encode the image and return (encoded_bytes, mime_type)."""
        output = io.BytesIO()
        if output_format == 'JPEG':
            image.save(
                output,
                format='JPEG',
                quality=quality,
                optimize=True,
                progressive=True  # Progressive JPEG for better compression
            )
            return output.getvalue(), 'image/jpeg'
        if output_format == 'PNG':
            image.save(
                output,
                format='PNG',
                optimize=True,
                compress_level=9  # Maximum compression
            )
            return output.getvalue(), 'image/png'
        image.save(
            output,
            format='WEBP',
            quality=quality,
            method=6  # Slowest encoder setting, best compression (0-6)
        )
        return output.getvalue(), 'image/webp'

    @staticmethod
    def _log_reduction(label: str, original_size: int, compressed_size: int) -> None:
        """Log the size before and after compression."""
        compression_ratio = (1 - compressed_size / original_size) * 100 if original_size > 0 else 0
        logger.info(
            f"{label} compressed: {original_size} bytes -> {compressed_size} bytes "
            f"({compression_ratio:.1f}% reduction)"
        )

    def compress_pdf(
        self,
        file_content: bytes,
        quality: str = 'screen'  # 'screen', 'ebook', 'printer', 'prepress'
    ) -> bytes:
        """
        Compress a PDF file

        The page content streams are recompressed losslessly. This is
        best-effort: if the PDF library is missing, the file cannot be
        processed, or the rewritten file is not smaller, the original bytes
        are returned unchanged and no exception is raised.

        Args:
            file_content: Original PDF file content as bytes
            quality: Compression quality ('screen', 'ebook', 'printer',
                'prepress'). Accepted for compatibility; all values currently
                produce the same lossless result.

        Returns:
            Compressed PDF as bytes, or file_content itself if compression
            failed or did not reduce the size
        """
        try:
            PdfReader, PdfWriter = self._load_pdf_backend()

            pdf_writer = PdfWriter()
            for page in PdfReader(io.BytesIO(file_content)).pages:
                pdf_writer.add_page(page)

            # Compress the pages owned by the writer, not the reader's pages:
            # the library's documented usage operates on writer pages.
            for page in pdf_writer.pages:
                page.compress_content_streams()

            # Write to bytes
            output = io.BytesIO()
            pdf_writer.write(output)
            compressed_bytes = output.getvalue()
        except Exception as e:
            logger.error(f"Error compressing PDF: {str(e)}")
            # If compression fails, return original
            logger.warning("PDF compression failed, returning original file")
            return file_content

        original_size = len(file_content)
        compressed_size = len(compressed_bytes)
        if compressed_size >= original_size:
            # Rewriting an already compact PDF can make it bigger.
            logger.info(
                f"PDF not reduced by compression ({original_size} bytes -> "
                f"{compressed_size} bytes), returning original file"
            )
            return file_content

        self._log_reduction("PDF", original_size, compressed_size)
        return compressed_bytes

    @staticmethod
    def _load_pdf_backend():
        """Return (PdfReader, PdfWriter) from pypdf, or from PyPDF2 as fallback."""
        try:
            from pypdf import PdfReader, PdfWriter
        except ImportError:
            try:
                from PyPDF2 import PdfReader, PdfWriter
            except ImportError:
                raise ValueError(
                    "PyPDF2 or pypdf library is required for PDF compression"
                ) from None
        return PdfReader, PdfWriter

    @classmethod
    def _detect_mime_type(
        cls,
        file_content: bytes,
        mime_type: Optional[str],
        filename: Optional[str]
    ) -> str:
        """Determine the MIME type from the hint, the extension or the content."""
        if mime_type:
            return mime_type
        if filename:
            # Try to detect from extension
            ext = os.path.splitext(filename)[1].lower()
            return cls._EXTENSION_MIME_TYPES.get(ext, 'application/octet-stream')
        # Try to detect from content
        try:
            with Image.open(io.BytesIO(file_content)) as probe:
                detected_format = probe.format
        except Exception:
            # Not a readable image; check if it's a PDF
            if file_content[:4] == b'%PDF':
                return 'application/pdf'
            return 'application/octet-stream'
        return f'image/{detected_format.lower()}' if detected_format else 'image/jpeg'

    def compress_document(
        self,
        file_content: bytes,
        mime_type: Optional[str] = None,
        filename: Optional[str] = None,
        max_image_width: int = MAX_IMAGE_WIDTH
    ) -> Tuple[bytes, str]:
        """
        Compress a document (image or PDF) based on its type

        The type is taken from mime_type if given, otherwise from the
        filename extension, otherwise from the content itself. MIME types are
        matched case-insensitively and without parameters.

        Args:
            file_content: Original file content as bytes
            mime_type: MIME type of the file (optional, will be detected if not provided)
            filename: Original filename (optional, used for format detection)
            max_image_width: Maximum width for images (default: 500)

        Returns:
            Tuple of (compressed_bytes, mime_type). Files of unknown type are
            returned unchanged together with the detected type.

        Raises:
            ValueError: If the file is treated as an image but cannot be processed
        """
        detected_type = self._detect_mime_type(file_content, mime_type, filename)

        # "Image/PNG" or "application/pdf; charset=binary" must still match.
        normalized_type = detected_type.split(';', 1)[0].strip().lower()

        # Compress based on type
        if normalized_type.startswith('image/'):
            # convert_to_webp=False: compress_image picks the output format
            # (JPEG for JPEG and opaque PNG sources, WebP otherwise).
            return self.compress_image(
                file_content,
                max_width=max_image_width,
                convert_to_webp=False
            )
        if normalized_type == 'application/pdf':
            return self.compress_pdf(file_content), 'application/pdf'

        # Unknown type, return as-is
        logger.warning(f"Unknown file type: {detected_type}, returning original")
        return file_content, detected_type

