#!/usr/bin/env python3
"""
Production Voice Agent - Direct Context Approach
Full schedule in context (simplest & most accurate!)
No RAG, No Neo4j - just pure LLM with 128K context
"""

import asyncio
import os
import sys
import json
from pathlib import Path
from typing import Optional

# Core imports
import websockets
from loguru import logger
from dotenv import load_dotenv

# Local model imports
from faster_whisper import WhisperModel
from openai import AsyncOpenAI
import subprocess
import tempfile
import wave
import numpy as np
from scipy import signal
from faster_asr import MyAsrModel
from vieneu import Vieneu
# Load environment
load_dotenv()


# ============================================================================
# CONFIGURATION
# ============================================================================

class Config:
    """Static settings for the ASR, LLM, TTS and WebSocket components.

    Every environment-backed value is read once, when this class body is
    executed, and falls back to the default given here when the variable
    is not set.
    """

    # --- ASR (faster-whisper style options) --------------------------------
    # Size names, e.g.: tiny, base, small, medium, large-v2, large-v3
    WHISPER_MODEL = os.environ.get("WHISPER_MODEL", "medium")
    WHISPER_DEVICE = os.environ.get("WHISPER_DEVICE", "cuda")
    WHISPER_COMPUTE_TYPE = os.environ.get("WHISPER_COMPUTE_TYPE", "float16")
    WHISPER_LANGUAGE = os.environ.get("WHISPER_LANGUAGE", "en")

    # --- LLM served by vLLM (OpenAI-compatible endpoint) -------------------
    VLLM_BASE_URL = os.environ.get("VLLM_BASE_URL", "http://localhost:8000/v1")
    VLLM_MODEL = os.environ.get("VLLM_MODEL", "./models/Qwen/Qwen2.5-7B-Instruct")

    # --- VieNeu TTS ---------------------------------------------------------
    VIENEU_MODEL_DIR = os.environ.get("VIENEU_MODEL_DIR", "vieneu-0.3B")
    # A falsy value (e.g. an empty string) makes VoiceAgent skip the preset
    # voice lookup and use the engine's default voice.
    VIENEU_VOICE_ID = os.environ.get("VIENEU_VOICE_ID", "Ly")

    # --- WebSocket endpoint -------------------------------------------------
    WS_HOST = os.environ.get("WS_HOST", "0.0.0.0")
    WS_PORT = int(os.environ.get("WS_PORT", "8765"))

    # --- Audio --------------------------------------------------------------
    SAMPLE_RATE = 16000
    # Seconds of audio buffered before a transcription pass is attempted
    CHUNK_DURATION = 7

    # --- LLM sampling -------------------------------------------------------
    MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "150"))
    TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.3"))

    # --- Knowledge text placed verbatim into every prompt --------------------
    SCHEDULE_FILE = os.environ.get("SCHEDULE_FILE", "./data/viet.txt")

    # --- System prompt sent as the first chat message ------------------------
    SYSTEM_PROMPT = os.environ.get("SYSTEM_PROMPT", """You are a vietnamese call centre agent.

CRITICAL RULES:
1. Answer ONLY in Vietnamese language
2. Keep responses SHORT (1-2 sentences maximum)
3. NO markdown, bullets, asterisks, or special formatting
4. Speak conversationally and naturally

Examples:
Q: "Alo, cho hỏi ngày mai khu em có cắt điện không ạ?"
A: "Vui lòng cho tôi biết địa điểm của bạn."

Q: "Alo, khu nhà tôi bị mất điện từ sáng nay, báo cáo giúp tôi với!"
A: "Tôi có thể biết tên địa phương của bạn không?"
""")
# ============================================================================
# VOICE AGENT
# ============================================================================

class VoiceAgent:
    """Production voice agent with direct context approach"""
    
    def __init__(self, use_schedule: bool = True):
        """Load the speech models, create the LLM client and read the schedule.

        Args:
            use_schedule: When True and ``Config.SCHEDULE_FILE`` exists, the
                whole file is read into ``self.full_schedule``. Otherwise the
                schedule stays an empty string.
        """
        # NOTE(review): the ASR model path, device and compute type are
        # hard-coded here; the WHISPER_* values in Config are not consulted.
        logger.info("Loading faster-asr  model")
        self.MyAsrModel = MyAsrModel(
            "model_ct2_fp16",
            device="cuda",
            compute_type="float16",
        )
        logger.info("✅ faster-asr  loaded")

        # Initialize VieNeu TTS
        logger.info(f"Loading VieNeu TTS from: {Config.VIENEU_MODEL_DIR}")
        self.tts = Vieneu(Config.VIENEU_MODEL_DIR)

        # Load the preset voice only when a voice id is configured; with a
        # falsy id, voice_data stays None and is passed as-is to tts.infer().
        self.voice_data = None
        if Config.VIENEU_VOICE_ID:
            self.voice_data = self.tts.get_preset_voice(Config.VIENEU_VOICE_ID)
            logger.info(f"✅ VieNeu loaded with voice: {Config.VIENEU_VOICE_ID}")
        else:
            logger.info("✅ VieNeu loaded with default voice")

        # Initialize LLM client (OpenAI-compatible API served by vLLM)
        logger.info(f"Connecting to vLLM: {Config.VLLM_BASE_URL}")
        self.llm_client = AsyncOpenAI(
            base_url=Config.VLLM_BASE_URL,
            api_key="EMPTY"
        )

        # Load the FULL schedule text into memory; it is embedded verbatim in
        # every prompt built by generate_response().
        self.full_schedule = ""
        if use_schedule:
            if os.path.exists(Config.SCHEDULE_FILE):
                logger.info(f"Loading full schedule from: {Config.SCHEDULE_FILE}")
                # Read explicitly as UTF-8 so non-ASCII (Vietnamese) text does
                # not depend on the platform's default encoding.
                with open(Config.SCHEDULE_FILE, 'r', encoding='utf-8') as f:
                    self.full_schedule = f.read()

                # Rough size estimate: whitespace-separated words stand in
                # for tokens, compared against a 128K-token window.
                estimated_tokens = len(self.full_schedule.split())
                context_usage = estimated_tokens / 131072 * 100

                logger.info("✅ Schedule loaded:")
                logger.info(f"   Characters: {len(self.full_schedule)}")
                logger.info(f"   Est. tokens: ~{estimated_tokens}")
                logger.info(f"   Context usage: {context_usage:.2f}%")
            else:
                # Make the "no schedule" situation visible instead of silently
                # answering without any schedule text.
                logger.warning(f"Schedule file not found: {Config.SCHEDULE_FILE}")

        # Conversation history (user/assistant messages, no system prompt)
        self.messages = []

        # Raw incoming audio accumulated by process_audio()
        self.audio_buffer = bytearray()

        logger.info("✅ Voice agent initialized")
    
    def __del__(self):
        """Best-effort shutdown of the TTS engine when the agent is collected.

        Does nothing if ``self.tts`` was never assigned (for example when
        ``__init__`` failed before the TTS engine was created). Any error
        raised while closing is logged rather than propagated.
        """
        if not hasattr(self, 'tts'):
            return
        try:
            self.tts.close()
            logger.info("✅ VieNeu TTS closed")
        except Exception as exc:
            logger.error(f"Error closing TTS: {exc}")
    
    async def transcribe(self, audio_bytes: bytes) -> str:
        """Transcribe raw 16-bit PCM audio to text.

        Args:
            audio_bytes: Little-endian int16 samples as received from the
                client. A dangling trailing byte (odd length) is ignored.

        Returns:
            The recognized text with surrounding whitespace stripped, or an
            empty string when there is no audio or recognition fails.
        """
        try:
            # int16 samples are 2 bytes each. Reading an explicit sample count
            # keeps an odd-length buffer from raising and discarding the
            # whole window.
            sample_count = len(audio_bytes) // 2
            if sample_count == 0:
                return ""

            audio_array = np.frombuffer(audio_bytes, dtype=np.int16, count=sample_count)
            # Scale to float32 in [-1.0, 1.0)
            audio_float = audio_array.astype(np.float32) / 32768.0

            def _run_asr() -> str:
                segments, _info = self.MyAsrModel.transcribe(
                    audio_float,
                    language="vi",
                    vad_filter=True
                )
                # Segments may be produced lazily, so they are consumed here,
                # inside the worker, rather than back on the event loop.
                return " ".join(seg.text for seg in segments).strip()

            # Inference is blocking; run it in the default executor so the
            # event loop can keep servicing other connections meanwhile.
            # NOTE(review): assumes MyAsrModel.transcribe may be called from a
            # worker thread - confirm against the ASR backend.
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, _run_asr)

        except Exception as e:
            logger.error(f"Transcription error: {e}")
            return ""
    
    def _format_conversation_history(self) -> str:
        """Render the most recent exchanges as plain ``Q:``/``A:`` text.

        Messages are paired by role rather than by position: each user
        message is matched with the next assistant message. This works
        whether or not ``self.messages`` starts with a system entry, and an
        unanswered trailing user message is simply left out.

        Returns:
            Up to the last two exchanges, separated by a blank line, or an
            empty string when there is no complete exchange.
        """
        exchanges = []
        pending_question = None
        for message in self.messages:
            role = message.get("role")
            content = message.get("content") or ""
            if role == "user":
                pending_question = self._extract_question(content)
            elif role == "assistant" and pending_question is not None:
                exchanges.append(f"Q: {pending_question}\nA: {content}")
                pending_question = None

        return "\n\n".join(exchanges[-2:])  # Last 2 exchanges

    @staticmethod
    def _extract_question(user_msg: str) -> str:
        """Return only the question part of a stored user prompt."""
        # (start marker, end marker) pairs; the last pair matches the prompt
        # layout built by generate_response().
        markers = (
            ("CURRENT QUESTION:", "INSTRUCTIONS:"),
            ("CAREGIVER QUESTION:", "INSTRUCTIONS:"),
            ("User question:", "Answer based ONLY"),
        )
        for start, end in markers:
            if start in user_msg:
                return user_msg.split(start)[1].split(end)[0].strip()
        # No known marker: keep a short prefix of the raw message
        return user_msg[:100]
    
    def _is_valid_schedule_query(self, text: str) -> bool:
        """Heuristically decide whether *text* is an English schedule question.

        The checks run in order and the first failing one rejects the text:
        minimum length, minimum word count, presence of a schedule keyword
        (substring match on the lower-cased text), share of ASCII characters,
        and absence of a character repeated five or more times in a row.

        Returns:
            True when every check passes, False otherwise (the reason is
            logged at debug level).
        """
        import re

        normalized = text.lower().strip()

        # Length is measured on the raw text, not the normalized copy
        if len(text) < 5:
            logger.debug(f"Too short: '{text}'")
            return False

        # A single word is not treated as a question
        if len(normalized.split()) < 2:
            logger.debug(f"Too few words: '{text}'")
            return False

        schedule_keywords = (
            # Time-related
            'time', 'when', 'what time', 'hour',
            # Activities
            'wake', 'sleep', 'bed', 'breakfast', 'lunch', 'dinner', 'snack', 'eat', 'meal',
            'bath', 'bathroom', 'toilet', 'wash',
            'medicine', 'medication', 'pill',
            # Questions
            'what', 'when', 'where', 'how', 'should', 'do', 'does',
            # Schedule terms
            'next', 'after', 'before', 'now', 'today', 'schedule', 'routine',
        )
        if not any(kw in normalized for kw in schedule_keywords):
            logger.debug(f"No schedule keywords: '{text}'")
            return False

        # Mostly-ASCII text is taken as a proxy for English
        ascii_ratio = sum(1 for ch in text if ch.isascii()) / len(text)
        if ascii_ratio < 0.7:
            logger.debug(f"Non-English detected: '{text}' (ASCII: {ascii_ratio:.2f})")
            return False

        # The same character five or more times in a row is treated as garbage
        if re.search(r'(.)\1{4,}', text) is not None:
            logger.debug(f"Repeated chars (garbage): '{text}'")
            return False

        return True

    # clean_text_for_speech() is defined once, after generate_response(); an
    # identical earlier copy here was always shadowed by that definition.
    
    async def generate_response(self, user_text: str) -> str:
        """Ask the LLM to answer *user_text* with the full schedule in context.

        The schedule text is embedded only in the prompt for the current
        turn. History keeps the bare question and answer, so the schedule is
        not repeated once per remembered turn. History is updated only after
        a successful completion, so a failed request never leaves an
        unanswered user message behind.

        Args:
            user_text: The transcribed user utterance.

        Returns:
            The assistant reply cleaned for TTS, or a short Vietnamese
            apology when the LLM request fails.
        """
        # Build prompt with FULL schedule
        prompt = f"""You are a Vietnamese voice assistant. Reply to user questions in Vietnamese, no English:

{self.full_schedule}

===================

User question: {user_text}

Answer based ONLY on the schedule above. Keep it SHORT (1-2 sentences). NO markdown or bullets.

Answer:"""

        # Keep conversation history manageable (last 3 Q&A pairs)
        if len(self.messages) > 6:
            self.messages = self.messages[-6:]

        logger.info(f"User: {user_text}")

        try:
            # Generate with vLLM: system prompt, prior turns, then this turn
            response = await self.llm_client.chat.completions.create(
                model=Config.VLLM_MODEL,
                messages=[
                    {"role": "system", "content": Config.SYSTEM_PROMPT},
                    *self.messages,
                    {"role": "user", "content": prompt},
                ],
                temperature=Config.TEMPERATURE,
                max_tokens=Config.MAX_TOKENS
            )
            # content may be None; treat that as an empty reply
            assistant_text = (response.choices[0].message.content or "").strip()
        except Exception as e:
            logger.error(f"LLM error: {e}")
            # The reply is spoken to a Vietnamese caller by the TTS engine and
            # the agent is instructed to answer only in Vietnamese, so the
            # fallback is Vietnamese too.
            return "Xin lỗi, tôi đang gặp sự cố khi xử lý yêu cầu của bạn."

        # Record the completed exchange as a user/assistant pair
        self.messages.append({"role": "user", "content": user_text})
        self.messages.append({"role": "assistant", "content": assistant_text})

        # Clean for TTS
        assistant_text_clean = self.clean_text_for_speech(assistant_text)

        logger.info(f"Assistant: {assistant_text_clean}")

        return assistant_text_clean
    
    def clean_text_for_speech(self, text: str) -> str:
        """Strip markdown-style markup so the text reads well through TTS.

        Removes runs of asterisks and underscores, ``#`` header markers,
        ``[label](target)`` links (label included), leading bullet or
        numbered-list markers on each line, and finally collapses all
        whitespace to single spaces.

        Returns:
            The cleaned, single-line text.
        """
        import re

        # (pattern, flags) pairs, applied in this order
        removals = (
            (r'\*+', 0),                          # emphasis asterisks
            (r'_+', 0),                           # emphasis underscores
            (r'#+\s*', 0),                        # header markers
            (r'\[.*?\]\(.*?\)', 0),               # markdown links
            (r'^\s*[-•*]\s+', re.MULTILINE),      # bullet markers
            (r'^\s*\d+\.\s+', re.MULTILINE),      # numbered-list markers
        )

        cleaned = text
        for pattern, flags in removals:
            cleaned = re.sub(pattern, '', cleaned, flags=flags)

        # Collapse whitespace (newlines included) and trim the ends
        return re.sub(r'\s+', ' ', cleaned).strip()
    
    async def synthesize(self, text: str) -> bytes:
        """Synthesize *text* with VieNeu TTS and return raw PCM audio bytes.

        The text is stripped, truncated to 500 characters and
        whitespace-normalized before synthesis. The blocking TTS and file
        work runs in the default executor so the event loop stays responsive.

        Args:
            text: The text to speak.

        Returns:
            PCM audio bytes (resampled to ``Config.SAMPLE_RATE`` when the TTS
            output uses a different rate), or ``b''`` when the text is empty
            or synthesis fails.
        """
        try:
            # Validate
            if not text or not text.strip():
                logger.warning("Empty text for TTS")
                return b''

            text = text.strip()
            if len(text) > 500:
                logger.warning(f"Text too long ({len(text)} chars), truncating")
                text = text[:500]

            # Normalize whitespace
            text = ' '.join(text.split())

            logger.debug(f"Synthesizing: {text}")

            # NOTE(review): assumes the VieNeu engine may be called from a
            # worker thread - confirm against the TTS backend.
            loop = asyncio.get_running_loop()
            audio_data = await loop.run_in_executor(
                None, self._synthesize_blocking, text
            )

            if audio_data:
                logger.debug(f"✅ Synthesized {len(audio_data)} bytes of audio")
            return audio_data

        except Exception as e:
            logger.error(f"TTS error: {e}")
            return b''

    def _synthesize_blocking(self, text: str) -> bytes:
        """Run TTS inference for *text* and return its PCM bytes (blocking)."""
        # voice_data is None when no preset voice was configured
        audio_spec = self.tts.infer(text=text, voice=self.voice_data)

        # A private temporary directory replaces tempfile.mktemp(): the path
        # cannot be pre-created by another process, and the directory is
        # removed on every exit path.
        with tempfile.TemporaryDirectory() as work_dir:
            audio_file = os.path.join(work_dir, "tts.wav")
            self.tts.save(audio_spec, audio_file)

            if not os.path.exists(audio_file):
                logger.error("VieNeu didn't create audio file")
                return b''

            file_size = os.path.getsize(audio_file)
            if file_size < 100:
                logger.error(f"Audio file too small: {file_size} bytes")
                return b''

            try:
                with wave.open(audio_file, 'rb') as wav:
                    source_rate = wav.getframerate()
                    channels = wav.getnchannels()
                    sample_width = wav.getsampwidth()
                    audio_data = wav.readframes(wav.getnframes())
            except wave.Error as e:
                logger.error(f"WAV read error: {e}")
                return b''

        # Already at the target rate (or nothing to convert): pass through
        if source_rate == Config.SAMPLE_RATE or not audio_data:
            return audio_data

        # The resampler below interprets the data as 16-bit samples
        if sample_width != 2:
            logger.error(f"Unsupported WAV sample width for resampling: {sample_width} bytes")
            return b''

        return self._resample_pcm16(
            audio_data, channels, source_rate, Config.SAMPLE_RATE
        )

    @staticmethod
    def _resample_pcm16(pcm: bytes, channels: int, source_rate: int, target_rate: int) -> bytes:
        """Resample interleaved 16-bit PCM from *source_rate* to *target_rate*."""
        samples = np.frombuffer(pcm, dtype=np.int16, count=len(pcm) // 2)

        # One row per frame, so each channel is resampled on its own instead
        # of filtering across interleaved samples.
        channels = max(int(channels), 1)
        usable = len(samples) - len(samples) % channels
        frames = samples[:usable].reshape(-1, channels)

        num_frames = int(frames.shape[0] * target_rate / source_rate)
        if num_frames <= 0:
            return b''

        resampled = signal.resample(frames, num_frames, axis=0)

        # FFT resampling can overshoot the int16 range; clip before casting so
        # peaks saturate instead of wrapping around.
        limits = np.iinfo(np.int16)
        clipped = np.clip(np.rint(resampled), limits.min, limits.max)
        return clipped.astype(np.int16).tobytes()
    
    async def process_audio(self, audio_chunk: bytes) -> Optional[bytes]:
        """Buffer *audio_chunk* and run the full pipeline once enough is queued.

        Audio accumulates until at least ``Config.CHUNK_DURATION`` seconds are
        buffered. The whole buffer is then transcribed, answered by the LLM
        and synthesized.

        Returns:
            The synthesized reply audio, or None while more audio is needed
            or when nothing was transcribed.
        """
        self.audio_buffer.extend(audio_chunk)

        # Bytes per window: samples per second * 2 bytes (16-bit) * seconds
        required_bytes = Config.SAMPLE_RATE * 2 * Config.CHUNK_DURATION
        if len(self.audio_buffer) < required_bytes:
            return None  # Need more audio

        # Take everything buffered so far and start a fresh window
        pending_audio = bytes(self.audio_buffer)
        self.audio_buffer.clear()

        # 1. Speech -> text
        transcript = await self.transcribe(pending_audio)
        if not transcript:
            return None

        # 2. Text -> reply text
        reply_text = await self.generate_response(transcript)

        # 3. Reply text -> speech
        return await self.synthesize(reply_text)


# ============================================================================
# WEBSOCKET SERVER
# ============================================================================

async def handle_client(websocket):
    """Serve one WebSocket client for the lifetime of its connection.

    A new ``VoiceAgent`` is constructed for every connection (which runs
    ``VoiceAgent.__init__``). A spoken greeting is sent first. After that,
    each binary frame is fed to the agent and any reply audio is sent back.
    Non-binary frames are ignored.
    """
    peer = websocket.remote_address
    client_id = f"{peer[0]}:{peer[1]}"
    logger.info(f"✅ Client connected: {client_id}")

    # Per-connection agent with the schedule loaded into its prompt context
    agent = VoiceAgent(use_schedule=True)

    try:
        # Greet the caller before listening
        greeting_audio = await agent.synthesize("Xin chào, cảm ơn bạn đã sử dụng dịch vụ tổng đài.")
        if greeting_audio:
            await websocket.send(greeting_audio)

        async for frame in websocket:
            # Only binary frames carry audio
            if not isinstance(frame, bytes):
                continue

            reply_audio = await agent.process_audio(frame)
            if reply_audio:
                await websocket.send(reply_audio)

    except websockets.exceptions.ConnectionClosed:
        # Normal end of a call
        logger.info(f"Client disconnected: {client_id}")
    except Exception as exc:
        logger.error(f"Error handling client {client_id}: {exc}")
    finally:
        logger.info(f"Cleaned up connection: {client_id}")


async def start_server():
    """Run the WebSocket server on ``Config.WS_HOST:Config.WS_PORT`` until cancelled."""
    host, port = Config.WS_HOST, Config.WS_PORT
    logger.info(f"Starting WebSocket server on {host}:{port}")

    server = websockets.serve(
        handle_client,
        host,
        port,
        max_size=10_000_000,  # 10MB max message size
    )
    async with server:
        logger.info(f"✅ Voice agent listening on ws://{host}:{port}")
        logger.info("Press Ctrl+C to stop")
        # A future that is never resolved keeps the server alive
        await asyncio.Future()


# ============================================================================
# STARTUP CHECKS
# ============================================================================

def check_dependencies():
    """Verify that the external pieces the agent needs are reachable.

    Checks that the vLLM endpoint answers ``GET {VLLM_BASE_URL}/models`` with
    HTTP 200, that the VieNeu model directory exists, and that
    ``Config.WHISPER_MODEL`` is either a known size name or an existing path.

    Returns:
        True when every check passes. Otherwise False, after logging each
        problem that was found.
    """
    errors = []

    # Check vLLM
    import requests
    try:
        response = requests.get(f"{Config.VLLM_BASE_URL}/models", timeout=5)
        if response.status_code == 200:
            models = response.json()
            logger.info(f"✅ vLLM running with {len(models.get('data', []))} models")
        else:
            errors.append(f"vLLM returned status {response.status_code}")
    except Exception as e:
        errors.append(f"Cannot connect to vLLM at {Config.VLLM_BASE_URL}: {e}")

    # Check VieNeu model directory
    vieneu_path = Path(Config.VIENEU_MODEL_DIR)
    if not vieneu_path.exists():
        errors.append(f"VieNeu model directory not found: {Config.VIENEU_MODEL_DIR}")
    else:
        logger.info(f"✅ VieNeu model directory found: {Config.VIENEU_MODEL_DIR}")

    # Check Whisper model: a known size name or an existing local path.
    # large-v2 / large-v3 are included because Config advertises them as
    # valid values; without them those settings were reported as "not found".
    known_whisper_sizes = {
        "tiny", "base", "small", "medium", "large", "large-v2", "large-v3",
    }
    if Config.WHISPER_MODEL in known_whisper_sizes or Path(Config.WHISPER_MODEL).exists():
        # Report success only when the check actually passed
        logger.info(f"✅ Whisper model: {Config.WHISPER_MODEL}")
    else:
        errors.append(f"Whisper model not found: {Config.WHISPER_MODEL}")

    if errors:
        logger.error("❌ Dependency check failed:")
        for error in errors:
            logger.error(f"  - {error}")
        return False

    return True


# ============================================================================
# MAIN
# ============================================================================

async def main():
    """Print the startup banner, verify dependencies and run the server.

    Exits the process with status 1 when the dependency check fails.
    """
    rule = "="*70
    banner = (
        rule,
        "🎙️  LOCAL VOICE AGENT",
        rule,
        f"ASR Model:    {Config.WHISPER_MODEL} ({Config.WHISPER_DEVICE})",
        f"LLM Server:   {Config.VLLM_BASE_URL}",
        f"LLM Model:    {Config.VLLM_MODEL}",
        f"TTS Model:    VieNeu ({Config.VIENEU_MODEL_DIR})",
        f"WebSocket:    ws://{Config.WS_HOST}:{Config.WS_PORT}",
        rule,
        "",
    )
    for line in banner:
        logger.info(line)

    # Refuse to start when a required dependency is missing
    if not check_dependencies():
        hints = (
            "Please fix the errors above and try again",
            "",
            "Quick fixes:",
            "  1. Start vLLM: ./start_vllm.sh",
            "  2. Download VieNeu model: place in vieneu-0.3B directory",
            "  3. Check model paths in .env file",
        )
        for hint in hints:
            logger.error(hint)
        sys.exit(1)

    logger.info("")

    # Serve until interrupted
    try:
        await start_server()
    except KeyboardInterrupt:
        logger.info("")
        logger.info("Shutting down...")

if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C normally surfaces here, out of asyncio.run(), rather than
        # inside main(); handle it so the process exits without a traceback.
        logger.info("")
        logger.info("Shutting down...")
