"""
Kafka consumer service for processing delayed payin credit events.
Designed for high throughput (target: 1 crore, i.e. 10 million, transactions per minute).
"""

import json
import logging
import time
import sys
from typing import Dict, Any, List
from decimal import Decimal
from datetime import datetime
from confluent_kafka import Consumer, KafkaException, KafkaError, TopicPartition
from sqlalchemy.orm import Session
from app.core.config import settings
from app.services.wallet_service import WalletService
from app.repositories.payin_transaction_repository import PayinTransactionRepository
from app.common.enums import TransactionType
from app.core.database import SessionLocal

# Root logging setup, applied at import time. logging.basicConfig does nothing
# if the root logger already has handlers (e.g. configured by a host process).
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)


class KafkaConsumerService:
    """Kafka consumer service for processing delayed payin credit events.

    Each event carries a 'target_timestamp' that is compared against time.time().
    Events that are not yet due stay on the topic: the partition is seeked back to
    them. Offsets are committed manually, and only up to the earliest still-pending
    event of each partition.
    """

    def __init__(self):
        """Initialize Kafka consumer with optimized settings for high throughput"""
        self.consumer_config = {
            'bootstrap.servers': settings.KAFKA_BOOTSTRAP_SERVERS,
            'group.id': settings.KAFKA_CONSUMER_GROUP_ID,
            'auto.offset.reset': 'earliest',  # Start from beginning if no committed offset
            'enable.auto.commit': False,  # Manual commit for better control
            'session.timeout.ms': 30000,
            'heartbeat.interval.ms': 10000,
            'max.poll.interval.ms': 300000,  # 5 minutes - allow longer processing time
        }
        self.consumer = Consumer(self.consumer_config)
        # librdkafka subscriptions are asynchronous. A topic that does not exist yet
        # is reported later as UNKNOWN_TOPIC_OR_PART from consume(), and run() handles
        # that by waiting. Retrying subscribe() with identical arguments gained nothing.
        self.consumer.subscribe([settings.KAFKA_PAYIN_CREDIT_TOPIC])
        self.running = False
        self.batch_size = 100  # Process in batches for efficiency
        # Upper bound (seconds) on how long run() sleeps after a batch that only
        # yielded not-yet-due messages, so they are not re-fetched in a tight loop.
        self.not_ready_backoff_seconds = 0.5
        # Seconds until the soonest not-ready message of the last batch is due, or
        # None if that batch had none. Written by process_batch(), read by run().
        self._pending_wait_seconds = None

    @staticmethod
    def _rollback_quietly(db: Session) -> None:
        """Roll back *db*'s current transaction, logging (not raising) any failure.

        After a failed flush or commit, a SQLAlchemy session refuses further work
        until it is rolled back. Because run() shares one session across all
        messages, a missing rollback would make every later message fail too.
        """
        try:
            db.rollback()
        except Exception as e:
            logger.error(f"Error rolling back database session: {e}")

    def process_message(self, message: Dict[str, Any], db: Session) -> bool:
        """
        Process a single payin credit event.

        The wallet is credited with the payin's net amount (amount - charge). Both
        values are read from the payin transaction record, not from the Kafka payload.

        Args:
            message: Decoded Kafka message; the keys read here are 'payin_txn_id',
                'txnid' and 'target_timestamp'
            db: Database session; rolled back here if processing raises, so the
                caller can keep using it

        Returns:
            True if the wallet was credited, or had already been credited for this
            txnid. False if the event is not yet due, has an invalid
            target_timestamp, or its payin is missing, not in "success" status,
            refunded, or has an invalid/non-positive net amount. Also False if
            processing raised.
        """
        try:
            payin_txn_id = message.get('payin_txn_id')
            txnid = message.get('txnid')
            target_timestamp = message.get('target_timestamp')

            # Check if enough time has passed
            current_timestamp = time.time()
            if target_timestamp is None or target_timestamp == 0:
                logger.error(f"Invalid target_timestamp for txnid {txnid}: {target_timestamp}")
                return False
            if current_timestamp < target_timestamp:
                # Not yet time to process, skip for now
                logger.debug(f"Event for txnid {txnid} not ready yet. Target: {target_timestamp}, Current: {current_timestamp}")
                return False

            # Get payin transaction
            payin_repo = PayinTransactionRepository(db)
            payin_txn = payin_repo.get_by_id(str(payin_txn_id))

            if not payin_txn:
                logger.error(f"Payin transaction {payin_txn_id} not found")
                return False

            # Idempotency check: an existing wallet transaction with this txnid means the
            # credit already happened, e.g. the event was re-delivered after a seek-back
            # in process_batch() or after a restart before the offset was committed.
            # Imported locally as in the original (presumably to avoid an import cycle - TODO confirm).
            from app.repositories.wallet_transaction_repository import WalletTransactionRepository
            wallet_txn_repo = WalletTransactionRepository(db)
            if wallet_txn_repo.get_by_txnid(txnid):
                logger.info(f"Wallet already credited for txnid {txnid}, skipping")
                return True

            # Check if transaction is still in success status
            if payin_txn.status != "success":
                logger.warning(f"Payin transaction {payin_txn_id} status is {payin_txn.status}, not crediting wallet")
                return False

            if payin_txn.refunded:
                logger.warning(f"Payin transaction {payin_txn_id} is refunded, not crediting wallet")
                return False

            # Credit the net amount (amount - charge), using the authoritative values
            # from the payin transaction rather than the Kafka message. Converting via
            # str() keeps a float value at its shortest decimal form instead of its exact
            # binary expansion (e.g. 0.1 -> Decimal('0.1'), not 0.1000000000000000055...).
            # For Decimal/int/str values the conversion is a no-op.
            try:
                gross_amount = Decimal(str(payin_txn.amount))
                charge_amount = Decimal(str(payin_txn.charge or 0))
            except Exception as e:
                logger.error(f"Invalid amount/charge on payin_txn {payin_txn_id}: {e}")
                return False

            net_amount = gross_amount - charge_amount
            if net_amount <= 0:
                logger.error(
                    f"Computed net_amount <= 0 for payin_txn {payin_txn_id}: "
                    f"gross={gross_amount}, charge={charge_amount}"
                )
                return False

            wallet_service = WalletService(db)
            wallet_service.process_transaction(
                user_id=payin_txn.user_id,
                txn_type="payin",
                transaction_type=TransactionType.CREDIT.value,
                amount=net_amount,
                txnid=txnid,
                db=db,
                update_main=True
            )

            # Update settled status to true
            payin_repo.update(
                str(payin_txn_id),
                settled=True
            )

            # Log the amount actually credited (net), not the payload's 'amount'.
            logger.info(
                f"Successfully credited wallet for payin txnid {txnid}, amount: {net_amount} "
                f"(gross: {gross_amount}, charge: {charge_amount}), settled: true"
            )
            return True

        except Exception as e:
            logger.error(f"Error processing payin credit event: {e}", exc_info=True)
            self._rollback_quietly(db)
            return False

    def process_batch(self, messages: List[Any], db: Session) -> tuple[int, List[Any]]:
        """
        Process a batch of messages.

        Each message is decoded as UTF-8 JSON and classified by its 'target_timestamp':
          - invalid (undecodable, or target_timestamp missing/None/0): committed so it
            is skipped
          - ready: passed to process_message(), then committed whether or not that
            succeeded (failures are logged, not retried, to avoid infinite retries)
          - not ready: the partition is seeked back to its lowest not-ready offset, so
            Kafka returns that message, and everything after it, on a later poll

        No offset at or beyond a partition's lowest not-ready message is committed,
        so a restart or rebalance re-delivers it instead of skipping it. Ready
        messages after it are still processed now. If re-delivered, the
        already-credited ones are recognised by process_message()'s txnid
        idempotency check.

        Also stores in self._pending_wait_seconds the seconds until the soonest
        not-ready message is due, or None if there were none.

        Args:
            messages: List of Kafka messages (error events already filtered out)
            db: Database session

        Returns:
            Tuple of (processed_count, messages_to_commit)
            - processed_count: Number of messages process_message() handled successfully
            - messages_to_commit: List of messages whose offsets are safe to commit
        """
        processed_count = 0
        ready_messages = []
        invalid_messages = []
        # (topic, partition) -> (offset, txnid, target_timestamp) of the lowest-offset
        # not-ready message in that partition
        earliest_pending: dict[tuple[str, int], tuple[int, Any, Any]] = {}
        soonest_due = None  # smallest target_timestamp among all not-ready messages

        # Separate messages into ready, not ready and invalid
        current_timestamp = time.time()
        for msg in messages:
            try:
                message_data = json.loads(msg.value().decode('utf-8'))
                target_timestamp = message_data.get('target_timestamp', 0)

                # Skip messages with invalid target_timestamp
                if target_timestamp is None or target_timestamp == 0:
                    logger.error(f"Invalid target_timestamp in message: {message_data}")
                    invalid_messages.append(msg)
                    continue

                if current_timestamp >= target_timestamp:
                    ready_messages.append((msg, message_data))
                    # debug, not info: this runs once per message on the hot path
                    logger.debug(f"Message for txnid {message_data.get('txnid')} is ready to process (target: {target_timestamp}, current: {current_timestamp})")
                    continue

                wait_time = target_timestamp - current_timestamp
                if wait_time < 60:  # Only log if less than 1 minute away
                    logger.debug(f"Message for txnid {message_data.get('txnid')} not ready yet. Wait time: {wait_time:.2f} seconds")
                key = (msg.topic(), msg.partition())
                if key not in earliest_pending or msg.offset() < earliest_pending[key][0]:
                    earliest_pending[key] = (msg.offset(), message_data.get('txnid'), target_timestamp)
                if soonest_due is None or target_timestamp < soonest_due:
                    soonest_due = target_timestamp
            except Exception as e:
                logger.error(f"Error decoding message: {e}")
                # Invalid messages should be committed to skip them
                invalid_messages.append(msg)

        messages_to_commit = list(invalid_messages)

        # Process ready messages
        for msg, message_data in ready_messages:
            try:
                succeeded = self.process_message(message_data, db)
            except Exception as e:
                logger.error(f"Error processing message: {e}")
                succeeded = False
            if succeeded:
                processed_count += 1
            # Committed even on failure, to avoid infinite retries of the same message
            messages_to_commit.append(msg)

        # Seek each partition back to its LOWEST not-ready offset. seek() keeps a single
        # fetch position per partition, so seeking to every not-ready message in turn
        # would leave the position at the highest one and skip the earlier ones.
        for (topic, partition), (offset, txnid, target_timestamp) in earliest_pending.items():
            try:
                self.consumer.seek(TopicPartition(topic, partition, offset))
                logger.debug(f"Seeked back to offset {offset} in partition {partition} for txnid {txnid} (will check again in {target_timestamp - time.time():.1f}s)")
            except Exception as e:
                logger.error(f"Error seeking back for message {txnid}: {e}")
                # The message is not re-delivered in this session, but nothing at or after
                # it is committed below, so it is read again after a restart or rebalance.

        def _before_pending(msg: Any) -> bool:
            """True if *msg* lies before its partition's lowest not-ready offset (or there is none)."""
            pending = earliest_pending.get((msg.topic(), msg.partition()))
            return pending is None or msg.offset() < pending[0]

        messages_to_commit = [msg for msg in messages_to_commit if _before_pending(msg)]

        self._pending_wait_seconds = (
            None if soonest_due is None else max(0.0, soonest_due - time.time())
        )
        return processed_count, messages_to_commit

    def _without_error_events(self, messages: List[Any]) -> List[Any]:
        """Return the messages that carry data; log error events and drop them.

        For UNKNOWN_TOPIC_OR_PART this also waits 2 seconds, giving the topic time to appear.
        """
        valid_messages = []
        for msg in messages:
            error = msg.error()
            if not error:
                valid_messages.append(msg)
            elif error.code() == KafkaError._PARTITION_EOF:
                # End of partition - normal, we've read everything currently in it
                continue
            elif error.code() == KafkaError.UNKNOWN_TOPIC_OR_PART:
                logger.warning(f"Topic {settings.KAFKA_PAYIN_CREDIT_TOPIC} not available yet, waiting...")
                time.sleep(2)  # Wait before retrying
            else:
                logger.error(f"Consumer error: {error}")
        return valid_messages

    def _commit_offsets(self, messages: List[Any]) -> int:
        """Synchronously commit, per (topic, partition), the offset after the highest message in *messages*.

        This matches what commit(message=msg) commits (msg.offset() + 1), but
        issues a single request for all partitions.

        Returns:
            Number of partitions included in the commit (0 if none, or if the commit failed).
        """
        next_offsets: dict[tuple[str, int], int] = {}
        for msg in messages:
            key = (msg.topic(), msg.partition())
            next_offsets[key] = max(next_offsets.get(key, 0), msg.offset() + 1)
        if not next_offsets:
            return 0

        offsets = [
            TopicPartition(topic, partition, offset)
            for (topic, partition), offset in next_offsets.items()
        ]
        try:
            results = self.consumer.commit(offsets=offsets, asynchronous=False)
        except KafkaException as e:
            logger.error(f"Error committing offsets for {len(offsets)} partitions: {e}")
            return 0
        # A synchronous commit can still report failures per partition.
        for tp in results or []:
            if tp.error:
                logger.error(f"Error committing offset for partition {tp.partition}: {tp.error}")
        return len(offsets)

    def run(self):
        """Run the consumer loop until close() clears self.running, or on Ctrl-C.

        One database session is shared by the whole loop and rolled back after any
        loop error. The session and the consumer are closed when the loop ends.
        """
        self.running = True
        logger.info("Starting Kafka consumer for payin credit events")

        db = SessionLocal()
        try:
            while self.running:
                try:
                    # consume() returns a list of messages; timeout is in seconds
                    messages = self.consumer.consume(
                        num_messages=self.batch_size,
                        timeout=0.5  # Poll every 500ms to catch ready messages quickly
                    )

                    if not messages:
                        continue

                    # debug: pending messages are re-fetched after every seek-back, so
                    # these lines repeat for the same data
                    logger.debug(f"Received {len(messages)} messages from Kafka")

                    valid_messages = self._without_error_events(messages)
                    if not valid_messages:
                        continue

                    logger.debug(f"Processing {len(valid_messages)} valid messages")
                    processed, messages_to_commit = self.process_batch(valid_messages, db)

                    # Commit offsets only for messages that were processed or invalid
                    committed_partitions = self._commit_offsets(messages_to_commit)
                    if processed > 0:
                        logger.info(f"Processed {processed} messages, committed offsets for {committed_partitions} partitions")

                    # Nothing was committable and some messages are not yet due: back off,
                    # at most until the soonest one is due, instead of re-fetching the
                    # same pending messages in a tight loop.
                    if not messages_to_commit and self._pending_wait_seconds is not None:
                        time.sleep(min(self._pending_wait_seconds, self.not_ready_backoff_seconds))

                except KeyboardInterrupt:
                    logger.info("Consumer interrupted by user")
                    break
                except Exception as e:
                    logger.error(f"Error in consumer loop: {e}", exc_info=True)
                    self._rollback_quietly(db)
                    time.sleep(1)  # Brief pause before retrying

        finally:
            db.close()
            self.close()

    def close(self):
        """Stop the run loop and close the Kafka consumer; safe to call more than once."""
        self.running = False
        if self.consumer is not None:
            self.consumer.close()
            self.consumer = None
            logger.info("Kafka consumer closed")


def main():
    """Build the payin credit consumer and run it until it stops.

    A KeyboardInterrupt is logged as a shutdown. The consumer is closed on every
    exit path.
    """
    service = KafkaConsumerService()
    try:
        service.run()
    except KeyboardInterrupt:
        logger.info("Shutting down consumer...")
    finally:
        service.close()


# Run the consumer only when executed as a script, not on import.
if __name__ == "__main__":
    main()
