#!/usr/bin/env python3
"""
Phase 1: User Validation State Machine (Hybrid Approach)
Electricity Company Call Center - Vietnam

State Flow:
1. GREETING → Agent greets
2. AWAIT_PHONE_REQUEST → Ask for phone number
3. COLLECTING_PHONE → Confirm phone number
4. AWAIT_NAME_REQUEST → Ask for name
5. COLLECTING_NAME → Confirm name
6. VALIDATING_USER → Check against database
7. VALIDATION_COMPLETE → Inform result
8. ESCALATE_TO_HUMAN → Transfer to a human agent after repeated failures

Hybrid Approach:
- Python manages state transitions (deterministic)
- LLM extracts data (phone, name); agent replies come from fixed Vietnamese templates
"""

import asyncio
import csv
import difflib
import json
import os
import re
import sys
import unicodedata
from pathlib import Path
from typing import Optional, Dict, Any, Tuple
from dataclasses import dataclass
from enum import Enum

# Core imports
import websockets
from loguru import logger
from dotenv import load_dotenv

# Local model imports
from openai import AsyncOpenAI
import tempfile
import wave
import numpy as np
from scipy import signal
from faster_asr import MyAsrModel
from vieneu import Vieneu

# Load environment variables from a .env file into the process environment
# so the Config settings below can read them from the environment.
load_dotenv()


# ============================================================================
# CONFIGURATION
# ============================================================================

class Config:
    """Central settings for the validation voice agent.

    Entries read through ``os.environ.get`` can be overridden with environment
    variables; every other entry is a fixed constant.
    """

    # Speech recognition: language code for ASR
    WHISPER_LANGUAGE = os.environ.get("WHISPER_LANGUAGE", "vi")

    # OpenAI-compatible vLLM endpoint and the model name sent with each request
    VLLM_BASE_URL = os.environ.get("VLLM_BASE_URL", "http://localhost:8000/v1")
    VLLM_MODEL = os.environ.get("VLLM_MODEL", "./models/Qwen/Qwen2.5-7B-Instruct")

    # VieNeu text-to-speech: model directory and preset voice id
    VIENEU_MODEL_DIR = os.environ.get("VIENEU_MODEL_DIR", "vieneu-0.3B")
    VIENEU_VOICE_ID = os.environ.get("VIENEU_VOICE_ID", "Ly")

    # WebSocket server host / port
    WS_HOST = os.environ.get("WS_HOST", "0.0.0.0")
    WS_PORT = int(os.environ.get("WS_PORT", "8765"))

    # Audio format
    SAMPLE_RATE = 16000  # Hz
    CHUNK_DURATION = 5  # seconds

    # LLM generation parameters
    MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "150"))
    TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.3"))

    # Validation: attempts allowed before escalating to a human, and the
    # extraction confidence below which a value is read back for confirmation
    MAX_RETRY_PHONE = 3
    MAX_RETRY_NAME = 3
    PHONE_CONFIDENCE_THRESHOLD = 0.7
    NAME_CONFIDENCE_THRESHOLD = 0.6

    # CSV file holding customer records
    CUSTOMER_DB_FILE = os.environ.get("CUSTOMER_DB_FILE", "./data/customers.csv")


# ============================================================================
# DATA MODELS
# ============================================================================

class ValidationState(Enum):
    """State machine states.

    Each member's value equals its name. Transitions between them are driven
    by ValidationStateMachine.
    """
    GREETING = "GREETING"  # initial state; start() greets and moves to AWAIT_PHONE_REQUEST
    AWAIT_PHONE_REQUEST = "AWAIT_PHONE_REQUEST"  # waiting for the caller to say their phone number
    COLLECTING_PHONE = "COLLECTING_PHONE"  # waiting for yes/no on the extracted phone number
    AWAIT_NAME_REQUEST = "AWAIT_NAME_REQUEST"  # waiting for the caller to say their full name
    COLLECTING_NAME = "COLLECTING_NAME"  # waiting for yes/no on the extracted name
    VALIDATING_USER = "VALIDATING_USER"  # transient: set just before the database check runs
    VALIDATION_COMPLETE = "VALIDATION_COMPLETE"  # database check done; outcome stored in the context
    ESCALATE_TO_HUMAN = "ESCALATE_TO_HUMAN"  # retries exhausted; agent announced a transfer to a human


@dataclass
class ExtractionResult:
    """Result of one LLM extraction attempt (phone number or name)."""
    success: bool  # True when a value was extracted (format checks are applied by the extractor)
    value: Optional[str]  # the extracted value, or None on failure
    confidence: float  # confidence as reported by the LLM (the prompts ask for 0.0-1.0)
    needs_confirmation: bool  # whether the value should be read back to the caller
    failure_reason: Optional[str] = None  # human-readable (Vietnamese) reason when extraction failed
    user_provided_voluntarily: bool = False  # NOTE(review): not set by any visible code — confirm it is still used


@dataclass
class ValidationContext:
    """Mutable record of everything gathered while validating one caller.

    ValidationStateMachine fills these fields in as the caller gives and
    confirms their phone number and name, and after the database lookup.
    """
    # Phone number extracted from the caller's speech.
    phone_number: Optional[str] = None
    phone_confirmed: bool = False
    # Name extracted from the caller's speech.
    customer_name: Optional[str] = None
    name_confirmed: bool = False
    # Name and account id stored in the database for phone_number (if found).
    database_name: Optional[str] = None
    account_id: Optional[str] = None
    # existing_verified | existing_mismatch | new_customer
    customer_status: Optional[str] = None
    # Failed/rejected attempts so far, per field.
    retry_count_phone: int = 0
    retry_count_name: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Return a compact summary for logging.

        database_name and account_id are intentionally left out.
        """
        retry_summary = dict(phone=self.retry_count_phone, name=self.retry_count_name)
        return dict(
            phone=self.phone_number,
            phone_confirmed=self.phone_confirmed,
            name=self.customer_name,
            name_confirmed=self.name_confirmed,
            status=self.customer_status,
            retries=retry_summary,
        )


# ============================================================================
# CUSTOMER DATABASE
# ============================================================================

class CustomerDatabase:
    """Simple customer database backed by a CSV file.

    The CSV has a header row followed by rows whose first three columns are
    ``phone_number, customer_name, account_id`` (extra columns such as
    ``registration_date`` are ignored). Customers are indexed by phone number.
    If the file does not exist, a small sample database is written first.
    """

    def __init__(self, db_file: str):
        self.db_file = db_file
        self.customers: Dict[str, Dict[str, str]] = {}
        self._load_database()

    def _load_database(self):
        """Load customer rows from the CSV file into ``self.customers``.

        Creates a sample database when the file is missing. Read/parse errors
        are logged (not raised), leaving any rows loaded so far in place.
        """
        if not os.path.exists(self.db_file):
            logger.warning(f"Customer database not found: {self.db_file}")
            logger.info("Creating sample database...")
            self._create_sample_db()

        try:
            # newline='' is what the csv module expects; csv also handles
            # quoted fields (e.g. a name containing a comma) correctly.
            with open(self.db_file, 'r', encoding='utf-8', newline='') as f:
                reader = csv.reader(f)
                next(reader, None)  # skip header row
                for row in reader:
                    if len(row) < 3:
                        continue  # blank or malformed line
                    phone, name, account_id = (field.strip() for field in row[:3])
                    if not phone:
                        continue
                    self.customers[phone] = {
                        'name': name,
                        'account_id': account_id,
                        'phone': phone
                    }
            logger.info(f"✅ Loaded {len(self.customers)} customers from database")
        except (OSError, UnicodeDecodeError, csv.Error) as e:
            logger.error(f"Error loading database: {e}")

    def _create_sample_db(self):
        """Write a small sample customer CSV to ``self.db_file`` for testing."""
        directory = os.path.dirname(self.db_file)
        # A bare filename ("customers.csv") has no directory part, and
        # os.makedirs("") raises FileNotFoundError.
        if directory:
            os.makedirs(directory, exist_ok=True)
        sample_data = """phone_number,customer_name,account_id,registration_date
0901234567,Nguyễn Văn An,KH001,2023-01-15
0912345678,Trần Thị Bình,KH002,2023-02-20
0923456789,Lê Văn Công,KH003,2023-03-10
0934567890,Phạm Thị Dung,KH004,2023-04-05
0945678901,Hoàng Văn Em,KH005,2023-05-12
"""
        with open(self.db_file, 'w', encoding='utf-8') as f:
            f.write(sample_data)
        logger.info(f"✅ Created sample database: {self.db_file}")

    def lookup(self, phone: str) -> Optional[Dict[str, str]]:
        """Return the customer record for ``phone`` (exact key match), or None."""
        return self.customers.get(phone)

    def fuzzy_match_name(self, input_name: str, db_name: str) -> float:
        """
        Fuzzy match Vietnamese names.

        Both names are lowercased, stripped of diacritics (``đ`` → ``d``,
        tone/vowel marks removed whether the text is precomposed or
        decomposed) and of all whitespace before comparison.

        Returns a similarity score in 0.0-1.0:
          * 1.0 when the normalized names are identical;
          * 0.0 when either normalized name is empty (an empty string is a
            substring of everything and must not count as a match);
          * otherwise the difflib.SequenceMatcher ratio, which tolerates an
            inserted/dropped letter or middle name without rewarding a short
            fragment of the registered name (e.g. "Văn An" vs "Nguyễn Văn An").
        """
        def normalize(text: str) -> str:
            # NFD splits letters from their combining marks; 'đ' has no
            # decomposition, so it is mapped explicitly.
            decomposed = unicodedata.normalize("NFD", text.lower()).replace('đ', 'd')
            base = ''.join(ch for ch in decomposed if not unicodedata.combining(ch))
            return ''.join(base.split())  # remove all whitespace

        norm_input = normalize(input_name)
        norm_db = normalize(db_name)

        if not norm_input or not norm_db:
            return 0.0
        if norm_input == norm_db:
            return 1.0
        return difflib.SequenceMatcher(None, norm_input, norm_db).ratio()


# ============================================================================
# LLM FUNCTIONS
# ============================================================================

class LLMFunctions:
    """LLM-backed data extraction plus template-based response generation.

    Phone-number and name extraction are delegated to the LLM, which is asked
    to answer in JSON; that JSON is then validated in Python. Agent replies
    are built from fixed Vietnamese templates (no LLM call).
    """

    def __init__(self, client: AsyncOpenAI):
        self.client = client

    async def _call_llm(self, messages: list, temperature: Optional[float] = None) -> str:
        """Send a chat-completion request and return the stripped reply text.

        Args:
            messages: OpenAI-style chat messages.
            temperature: Sampling temperature; defaults to ``Config.TEMPERATURE``.

        Returns:
            The reply text, or ``""`` if the request failed or the reply had
            no content. Errors are logged, never raised.
        """
        if temperature is None:
            temperature = Config.TEMPERATURE
        try:
            response = await self.client.chat.completions.create(
                model=Config.VLLM_MODEL,
                messages=messages,
                temperature=temperature,
                max_tokens=Config.MAX_TOKENS
            )
            content = response.choices[0].message.content
        except Exception as e:
            logger.error(f"LLM call error: {e}")
            return ""
        # content may be None (e.g. an empty completion)
        return (content or "").strip()

    @staticmethod
    def _parse_json_response(response: str) -> Dict[str, Any]:
        """Decode the JSON object contained in an LLM reply.

        Accepts replies wrapped in Markdown code fences (```json ... ``` or
        ``` ... ```) and replies with extra text around a single JSON object.

        Raises:
            ValueError: if no JSON object can be decoded
                (json.JSONDecodeError is a ValueError subclass).
        """
        text = response
        if "```json" in text:
            text = text.split("```json")[1].split("```")[0]
        elif "```" in text:
            text = text.split("```")[1].split("```")[0]
        text = text.strip()
        # Trim any prose the model put before/after the object.
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            text = text[start:end + 1]
        data = json.loads(text)
        if not isinstance(data, dict):
            raise ValueError(f"Expected a JSON object, got {type(data).__name__}")
        return data

    @staticmethod
    def _coerce_confidence(raw: Any) -> float:
        """Convert an LLM ``confidence`` field to a float in [0.0, 1.0]; invalid → 0.0."""
        try:
            value = float(raw)
        except (TypeError, ValueError):
            return 0.0
        if value != value:  # NaN
            return 0.0
        return min(max(value, 0.0), 1.0)

    @staticmethod
    def _coerce_flag(raw: Any, default: bool) -> bool:
        """Interpret a JSON boolean (or "true"/"false" string); otherwise ``default``."""
        if isinstance(raw, bool):
            return raw
        if isinstance(raw, str):
            lowered = raw.strip().lower()
            if lowered in ("true", "yes", "1"):
                return True
            if lowered in ("false", "no", "0"):
                return False
        return default

    @staticmethod
    def _failed_extraction(reason: str) -> ExtractionResult:
        """Build an unsuccessful ExtractionResult with the given reason."""
        return ExtractionResult(
            success=False,
            value=None,
            confidence=0.0,
            needs_confirmation=False,
            failure_reason=reason
        )

    async def extract_phone_number(self, user_speech: str, context: ValidationContext) -> ExtractionResult:
        """
        Extract a phone number from Vietnamese speech.

        Args:
            user_speech: Transcript of the caller's utterance.
            context: Current validation context (not used yet).

        Returns:
            ExtractionResult whose ``value`` is a 10-digit ASCII number
            starting with 0 on success (spaces, dashes and dots removed, a
            leading "+84" rewritten to "0"). Malformed LLM output never
            raises; it yields ``success=False`` with a Vietnamese reason.
        """
        system_prompt = """Bạn là trợ lý trích xuất số điện thoại từ cuộc gọi của khách hàng Việt Nam.

NHIỆM VỤ: Trích xuất số điện thoại từ câu nói của khách hàng.

QUY TẮC:
1. Định dạng số điện thoại: 10 chữ số, bắt đầu bằng 0 (VD: 0901234567)
2. Xử lý các từ lắp (uh, à, thì, ờ, hmm)
3. Nếu khách nói nhiều số, chọn số mà khách chỉ rõ là số của họ
4. Trả về độ tin cậy dựa trên độ rõ ràng của khách

OUTPUT FORMAT (JSON):
{
  "phone_number": "0901234567" hoặc null,
  "confidence": 0.0-1.0,
  "needs_confirmation": true/false,
  "failure_reason": "lý do nếu thất bại" hoặc null,
  "user_hesitant": true/false
}

VÍ DỤ:
Input: "Số tôi là 0901234567 ạ"
Output: {"phone_number": "0901234567", "confidence": 0.95, "needs_confirmation": false, "failure_reason": null, "user_hesitant": false}

Input: "Uh... cho tôi nhớ... 090... 0901234567"
Output: {"phone_number": "0901234567", "confidence": 0.75, "needs_confirmation": true, "failure_reason": null, "user_hesitant": true}

Input: "Tôi không nhớ số"
Output: {"phone_number": null, "confidence": 0.0, "needs_confirmation": false, "failure_reason": "khách không nhớ số", "user_hesitant": false}
"""

        user_prompt = f"""Khách hàng nói: "{user_speech}"

Trích xuất số điện thoại và trả về JSON:"""

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        response = await self._call_llm(messages, temperature=0.1)
        logger.debug(f"Phone extraction LLM response: {response}")

        try:
            data = self._parse_json_response(response)
        except ValueError as e:
            logger.error(f"Error parsing LLM phone extraction: {e}, Response: {response}")
            return self._failed_extraction("Lỗi xử lý")

        confidence = self._coerce_confidence(data.get("confidence"))

        phone = None
        raw_phone = data.get("phone_number")
        # str() guards against the model returning a JSON number.
        if raw_phone is not None and str(raw_phone).strip():
            phone = str(raw_phone).strip()
            for separator in (" ", "-", "."):
                phone = phone.replace(separator, "")
            if phone.startswith("+84"):
                phone = "0" + phone[3:]
            # isascii() excludes non-ASCII digits that isdigit() accepts.
            if not (phone.startswith("0") and len(phone) == 10
                    and phone.isascii() and phone.isdigit()):
                return self._failed_extraction("Định dạng số không hợp lệ")

        return ExtractionResult(
            success=phone is not None,
            value=phone,
            confidence=confidence,
            needs_confirmation=self._coerce_flag(
                data.get("needs_confirmation"),
                confidence < Config.PHONE_CONFIDENCE_THRESHOLD
            ),
            failure_reason=data.get("failure_reason")
        )

    async def extract_name(self, user_speech: str, context: ValidationContext) -> ExtractionResult:
        """
        Extract a customer's full name from Vietnamese speech.

        Args:
            user_speech: Transcript of the caller's utterance.
            context: Current validation context (not used yet).

        Returns:
            ExtractionResult whose ``value`` is the name with whitespace
            collapsed; names with fewer than two words are rejected.
            Malformed LLM output never raises; it yields ``success=False``.
        """
        system_prompt = """Bạn là trợ lý trích xuất tên khách hàng từ cuộc gọi.

NHIỆM VỤ: Trích xuất họ và tên của khách hàng.

QUY TẮC:
1. Tên Việt Nam thường có 2-4 từ (VD: Nguyễn Văn An, Trần Thị Bình)
2. Xử lý các từ lắp và lời xưng hô (tôi là, tên tôi là, con/cháu/em tên là)
3. Giữ nguyên dấu tiếng Việt
4. Trả về độ tin cậy dựa trên độ rõ ràng

OUTPUT FORMAT (JSON):
{
  "name": "Nguyễn Văn An" hoặc null,
  "confidence": 0.0-1.0,
  "needs_confirmation": true/false,
  "failure_reason": "lý do nếu thất bại" hoặc null
}

VÍ DỤ:
Input: "Tôi là Nguyễn Văn An ạ"
Output: {"name": "Nguyễn Văn An", "confidence": 0.95, "needs_confirmation": false, "failure_reason": null}

Input: "Tên... ờ... Trần Thị Bình"
Output: {"name": "Trần Thị Bình", "confidence": 0.70, "needs_confirmation": true, "failure_reason": null}

Input: "Tôi quên mất rồi"
Output: {"name": null, "confidence": 0.0, "needs_confirmation": false, "failure_reason": "khách quên tên"}
"""

        user_prompt = f"""Khách hàng nói: "{user_speech}"

Trích xuất tên và trả về JSON:"""

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        response = await self._call_llm(messages, temperature=0.1)
        logger.debug(f"Name extraction LLM response: {response}")

        try:
            data = self._parse_json_response(response)
        except ValueError as e:
            logger.error(f"Error parsing LLM name extraction: {e}, Response: {response}")
            return self._failed_extraction("Lỗi xử lý")

        confidence = self._coerce_confidence(data.get("confidence"))

        name = None
        raw_name = data.get("name")
        if raw_name is not None:
            # Collapse internal whitespace; an empty result means "no name".
            name = " ".join(str(raw_name).split()) or None
            if name is not None and len(name.split()) < 2:
                return self._failed_extraction("Tên không đầy đủ")

        return ExtractionResult(
            success=name is not None,
            value=name,
            confidence=confidence,
            needs_confirmation=self._coerce_flag(
                data.get("needs_confirmation"),
                confidence < Config.NAME_CONFIDENCE_THRESHOLD
            ),
            failure_reason=data.get("failure_reason")
        )

    async def generate_response(self, state: ValidationState, context: ValidationContext,
                               extra_info: Optional[Dict[str, Any]] = None) -> str:
        """
        Return the agent's Vietnamese reply for ``state``.

        Replies are fixed templates filled from ``context``; no LLM call is
        made, so wording stays deterministic. When ``extra_info`` contains
        ``retry=True`` and the state is AWAIT_PHONE_REQUEST/AWAIT_NAME_REQUEST,
        a retry prompt chosen by ``retry_count`` (1, or >= 2) is returned
        instead. States without a template get a generic error message.
        """
        if extra_info and extra_info.get("retry"):
            retry_count = extra_info.get("retry_count", 0)
            if state == ValidationState.AWAIT_PHONE_REQUEST:
                if retry_count == 1:
                    return "Xin lỗi, tôi chưa nghe rõ số điện thoại. Quý khách vui lòng nói lại số điện thoại ạ?"
                elif retry_count >= 2:
                    return "Tôi vẫn chưa nghe rõ số. Quý khách có thể nói từng số một được không ạ? Ví dụ: không chín không một..."

            elif state == ValidationState.AWAIT_NAME_REQUEST:
                if retry_count == 1:
                    return "Xin lỗi, tôi chưa nghe rõ tên. Quý khách vui lòng nói lại họ và tên ạ?"
                elif retry_count >= 2:
                    return "Quý khách vui lòng nói chậm và rõ họ tên được không ạ?"

        state_templates = {
            ValidationState.GREETING: "Chào bạn tới với tổng đài điện lực. Để phục vụ quý khách tốt hơn, xin cho biết số điện thoại của quý khách ạ?",

            ValidationState.AWAIT_PHONE_REQUEST: "Để phục vụ quý khách tốt hơn, xin cho biết số điện thoại của quý khách ạ?",

            ValidationState.COLLECTING_PHONE: f"Xin cảm ơn. Số điện thoại của quý khách là {context.phone_number}. Đúng không ạ?",

            ValidationState.AWAIT_NAME_REQUEST: "Xin cho biết họ và tên của quý khách ạ?",

            ValidationState.COLLECTING_NAME: f"Tên quý khách là {context.customer_name}. Đúng không ạ?",

            ValidationState.VALIDATION_COMPLETE: self._get_validation_complete_template(context)
        }

        template = state_templates.get(state, "")
        if not template:
            logger.warning(f"No template for state: {state}")
            return "Xin lỗi, có lỗi xảy ra. Vui lòng thử lại."

        return template

    def _get_validation_complete_template(self, context: ValidationContext) -> str:
        """Return the VALIDATION_COMPLETE reply matching ``context.customer_status``."""
        if context.customer_status == "existing_verified":
            return f"Xin chào anh/chị {context.customer_name}. Hệ thống đã xác nhận thông tin. Quý khách cần hỗ trợ gì ạ?"

        elif context.customer_status == "existing_mismatch":
            return f"Số điện thoại này đã đăng ký với tên {context.database_name}. Quý khách có phải là người đăng ký không ạ?"

        elif context.customer_status == "new_customer":
            return "Hệ thống chưa có thông tin của quý khách. Quý khách muốn đăng ký dịch vụ mới hay tra cứu thông tin ạ?"

        return "Đã xác nhận thông tin. Quý khách cần hỗ trợ gì ạ?"


# ============================================================================
# VALIDATION STATE MACHINE
# ============================================================================

class ValidationStateMachine:
    """
    State machine for Phase 1: User Validation.

    Hybrid approach: Python owns every state transition, while the LLM (via
    LLMFunctions) only extracts the phone number and name from speech.

    Flow: GREETING -> AWAIT_PHONE_REQUEST -> [COLLECTING_PHONE] ->
    AWAIT_NAME_REQUEST -> [COLLECTING_NAME] -> VALIDATING_USER ->
    VALIDATION_COMPLETE, or ESCALATE_TO_HUMAN once retries are exhausted.
    Bracketed confirmation states are skipped when extraction confidence is
    high enough.
    """

    # Extraction confidence above which the read-back confirmation is skipped.
    PHONE_SKIP_CONFIRM_CONFIDENCE = 0.9
    NAME_SKIP_CONFIRM_CONFIDENCE = 0.85
    # Fuzzy name similarity required for a database match to count as verified.
    NAME_MATCH_THRESHOLD = 0.8

    # Yes/no vocabulary for the confirmation steps (matched as whole words).
    _CONFIRM_WORDS = frozenset({'đúng', 'vâng', 'ừ', 'có', 'phải', 'yes', 'correct', 'ok', 'okay'})
    _REJECT_WORDS = frozenset({'không', 'sai', 'chưa', 'no', 'not', 'nope', 'wrong'})

    # Reply given to any further speech after the call has been escalated.
    _ESCALATION_HOLD_MESSAGE = "Xin vui lòng chờ trong giây lát."

    def __init__(self, llm_client: AsyncOpenAI, customer_db: CustomerDatabase):
        self.state = ValidationState.GREETING
        self.context = ValidationContext()
        self.llm_functions = LLMFunctions(llm_client)
        self.customer_db = customer_db
        self.pending_confirmation = None  # "phone" / "name" while awaiting a yes/no

        logger.info("🔄 Validation state machine initialized")

    def is_complete(self) -> bool:
        """Return True once validation has finished successfully."""
        return self.state == ValidationState.VALIDATION_COMPLETE

    def is_escalated(self) -> bool:
        """Return True once the call has been handed off to a human agent."""
        return self.state == ValidationState.ESCALATE_TO_HUMAN

    def get_context(self) -> ValidationContext:
        """Get current validation context."""
        return self.context

    async def start(self) -> str:
        """
        Start the validation flow.

        Returns the greeting (which already asks for the phone number) and
        moves the machine to AWAIT_PHONE_REQUEST.
        """
        self.state = ValidationState.GREETING
        response = await self.llm_functions.generate_response(self.state, self.context)

        # Auto-transition to next state
        self.state = ValidationState.AWAIT_PHONE_REQUEST

        return response

    async def process(self, user_speech: str) -> str:
        """
        Process one user utterance according to the current state.

        Returns:
            The agent's reply.
        """
        logger.info(f"📍 State: {self.state.value} | User: {user_speech}")

        if self.state == ValidationState.ESCALATE_TO_HUMAN:
            # Escalation is a known terminal state: keep the caller waiting
            # politely instead of reporting an "unknown state" error.
            return self._ESCALATION_HOLD_MESSAGE

        handlers = {
            ValidationState.AWAIT_PHONE_REQUEST: self._handle_await_phone,
            ValidationState.COLLECTING_PHONE: self._handle_collecting_phone,
            ValidationState.AWAIT_NAME_REQUEST: self._handle_await_name,
            ValidationState.COLLECTING_NAME: self._handle_collecting_name,
            ValidationState.VALIDATION_COMPLETE: self._handle_validation_complete,
        }
        handler = handlers.get(self.state)
        if handler is None:
            logger.error(f"Unknown state: {self.state}")
            return "Xin lỗi, có lỗi xảy ra."
        return await handler(user_speech)

    def _classify_confirmation(self, user_speech: str) -> Optional[bool]:
        """Classify a reply to a read-back question as yes / no / unclear.

        The utterance is NFC-normalized (so decomposed diacritics still
        match), lowercased and split into words; keywords must match whole
        words, so e.g. the filler "ừm" is not read as "ừ".

        Rejection words win over confirmation words: Vietnamese negates by
        prefixing, so "không phải" / "chưa đúng" ("not right") contain a
        confirmation word but mean no.

        Returns:
            False for a rejection, True for a confirmation, None if unclear.
        """
        words = set(re.findall(r"\w+", unicodedata.normalize("NFC", user_speech).lower()))
        if words & self._REJECT_WORDS:
            return False
        if words & self._CONFIRM_WORDS:
            return True
        return None

    async def _handle_await_phone(self, user_speech: str) -> str:
        """Handle AWAIT_PHONE_REQUEST: extract the phone number from speech."""
        result = await self.llm_functions.extract_phone_number(user_speech, self.context)
        logger.debug(f"Phone extraction result: {result}")

        if result.success and result.value:
            self.context.phone_number = result.value
            logger.info(f"✅ Extracted phone: {result.value} (confidence: {result.confidence})")

            if result.confidence > self.PHONE_SKIP_CONFIRM_CONFIDENCE and not result.needs_confirmation:
                # High confidence, skip confirmation
                self.context.phone_confirmed = True
                self.state = ValidationState.AWAIT_NAME_REQUEST
            else:
                self.state = ValidationState.COLLECTING_PHONE
                self.pending_confirmation = "phone"
            return await self.llm_functions.generate_response(self.state, self.context)

        # Extraction failed
        self.context.retry_count_phone += 1
        logger.warning(f"❌ Phone extraction failed (attempt {self.context.retry_count_phone}): {result.failure_reason}")

        if self.context.retry_count_phone >= Config.MAX_RETRY_PHONE:
            self.state = ValidationState.ESCALATE_TO_HUMAN
            return "Xin lỗi quý khách. Để được hỗ trợ tốt hơn, tôi sẽ chuyển máy cho nhân viên. Xin vui lòng chờ trong giây lát."

        return await self.llm_functions.generate_response(
            self.state,
            self.context,
            extra_info={"retry": True, "retry_count": self.context.retry_count_phone}
        )

    async def _handle_collecting_phone(self, user_speech: str) -> str:
        """Handle COLLECTING_PHONE: the caller answers "is this your number?".

        A rejection clears the number and asks again (escalating after
        Config.MAX_RETRY_PHONE attempts). Anything else, including an
        unclear answer, is treated as confirmation.
        """
        decision = self._classify_confirmation(user_speech)
        self.pending_confirmation = None

        if decision is False:
            self.context.phone_number = None
            self.context.retry_count_phone += 1
            self.state = ValidationState.AWAIT_PHONE_REQUEST
            logger.info("❌ Phone rejected by user")

            if self.context.retry_count_phone >= Config.MAX_RETRY_PHONE:
                self.state = ValidationState.ESCALATE_TO_HUMAN
                return "Xin lỗi quý khách. Để được hỗ trợ tốt hơn, tôi sẽ chuyển máy cho nhân viên."

            return await self.llm_functions.generate_response(
                self.state,
                self.context,
                extra_info={"retry": True, "retry_count": self.context.retry_count_phone}
            )

        if decision is None:
            # Unclear response: assume confirmed. (A real system would
            # handle silence/timeouts separately.)
            logger.info("⚠️  Unclear confirmation, assuming yes")
        else:
            logger.info(f"✅ Phone confirmed: {self.context.phone_number}")
        self.context.phone_confirmed = True
        self.state = ValidationState.AWAIT_NAME_REQUEST
        return await self.llm_functions.generate_response(self.state, self.context)

    async def _handle_await_name(self, user_speech: str) -> str:
        """Handle AWAIT_NAME_REQUEST: extract the caller's name from speech."""
        result = await self.llm_functions.extract_name(user_speech, self.context)
        logger.debug(f"Name extraction result: {result}")

        if result.success and result.value:
            self.context.customer_name = result.value
            logger.info(f"✅ Extracted name: {result.value} (confidence: {result.confidence})")

            if result.confidence > self.NAME_SKIP_CONFIRM_CONFIDENCE and not result.needs_confirmation:
                # High confidence, skip confirmation
                self.context.name_confirmed = True
                self.state = ValidationState.VALIDATING_USER
                return await self._validate_user()

            self.state = ValidationState.COLLECTING_NAME
            self.pending_confirmation = "name"
            return await self.llm_functions.generate_response(self.state, self.context)

        # Extraction failed
        self.context.retry_count_name += 1
        logger.warning(f"❌ Name extraction failed (attempt {self.context.retry_count_name}): {result.failure_reason}")

        if self.context.retry_count_name >= Config.MAX_RETRY_NAME:
            self.state = ValidationState.ESCALATE_TO_HUMAN
            return "Xin lỗi quý khách. Để được hỗ trợ tốt hơn, tôi sẽ chuyển máy cho nhân viên."

        return await self.llm_functions.generate_response(
            self.state,
            self.context,
            extra_info={"retry": True, "retry_count": self.context.retry_count_name}
        )

    async def _handle_collecting_name(self, user_speech: str) -> str:
        """Handle COLLECTING_NAME: the caller answers "is this your name?".

        A rejection clears the name and asks again (escalating after
        Config.MAX_RETRY_NAME attempts). Anything else, including an
        unclear answer, is treated as confirmation and triggers validation.
        """
        decision = self._classify_confirmation(user_speech)
        self.pending_confirmation = None

        if decision is False:
            self.context.customer_name = None
            self.context.retry_count_name += 1
            self.state = ValidationState.AWAIT_NAME_REQUEST
            logger.info("❌ Name rejected by user")

            if self.context.retry_count_name >= Config.MAX_RETRY_NAME:
                self.state = ValidationState.ESCALATE_TO_HUMAN
                return "Xin lỗi quý khách. Để được hỗ trợ tốt hơn, tôi sẽ chuyển máy cho nhân viên."

            return await self.llm_functions.generate_response(
                self.state,
                self.context,
                extra_info={"retry": True, "retry_count": self.context.retry_count_name}
            )

        if decision is None:
            logger.info("⚠️  Unclear confirmation, assuming yes")
        else:
            logger.info(f"✅ Name confirmed: {self.context.customer_name}")
        self.context.name_confirmed = True
        self.state = ValidationState.VALIDATING_USER
        return await self._validate_user()

    async def _validate_user(self) -> str:
        """
        Validate the collected phone/name against the customer database.

        Pure Python (no LLM). Sets ``context.customer_status`` to
        existing_verified, existing_mismatch or new_customer, moves to
        VALIDATION_COMPLETE and returns the matching reply.
        """
        logger.info(f"🔍 Validating: {self.context.phone_number} | {self.context.customer_name}")

        db_customer = self.customer_db.lookup(self.context.phone_number)

        if db_customer:
            self.context.database_name = db_customer['name']
            self.context.account_id = db_customer['account_id']

            similarity = self.customer_db.fuzzy_match_name(
                self.context.customer_name,
                db_customer['name']
            )
            logger.info(f"📊 Name similarity: {similarity:.2f}")

            if similarity > self.NAME_MATCH_THRESHOLD:
                self.context.customer_status = "existing_verified"
                logger.info(f"✅ Customer verified: {db_customer['account_id']}")
            else:
                self.context.customer_status = "existing_mismatch"
                logger.warning(f"⚠️  Name mismatch: {self.context.customer_name} vs {db_customer['name']}")
        else:
            self.context.customer_status = "new_customer"
            logger.info(f"🆕 New customer: {self.context.phone_number}")

        self.state = ValidationState.VALIDATION_COMPLETE
        # Reply comes from the VALIDATION_COMPLETE template for the status above.
        return await self.llm_functions.generate_response(self.state, self.context)

    async def _handle_validation_complete(self, user_speech: str) -> str:
        """Handle VALIDATION_COMPLETE (Phase 2 intent detection would start here)."""
        return "Validation complete. (Ready for Phase 2: Intent Detection)"


# ============================================================================
# VOICE AGENT (INTEGRATED)
# ============================================================================

class VoiceAgentWithValidation:
    """Voice agent with Phase 1 validation integrated.

    Each instance holds one conversation's state. Incoming 16-bit mono PCM is
    buffered and transcribed with the ASR model. The text is routed through
    ``ValidationStateMachine``, and the reply text is synthesized back to PCM
    with VieNeu TTS.
    """

    def __init__(self):
        # Initialize ASR
        logger.info("Loading ASR model...")
        self.asr_model = MyAsrModel(
            "model_ct2_fp16",
            device="cuda",
            compute_type="float16"
        )
        logger.info("✅ ASR loaded")

        # Initialize LLM client (OpenAI-compatible vLLM endpoint; "EMPTY" is a placeholder key)
        logger.info(f"Connecting to vLLM: {Config.VLLM_BASE_URL}")
        self.llm_client = AsyncOpenAI(
            base_url=Config.VLLM_BASE_URL,
            api_key="EMPTY"
        )
        logger.info("✅ LLM client connected")

        # Initialize TTS
        logger.info(f"Loading VieNeu TTS from: {Config.VIENEU_MODEL_DIR}")
        self.tts = Vieneu(Config.VIENEU_MODEL_DIR)
        self.voice_data = None
        if Config.VIENEU_VOICE_ID:
            self.voice_data = self.tts.get_preset_voice(Config.VIENEU_VOICE_ID)
            logger.info(f"✅ VieNeu loaded with voice: {Config.VIENEU_VOICE_ID}")
        else:
            logger.info("✅ VieNeu loaded with default voice")

        # Initialize customer database
        self.customer_db = CustomerDatabase(Config.CUSTOMER_DB_FILE)

        # Initialize validation state machine
        self.validation_sm = ValidationStateMachine(self.llm_client, self.customer_db)

        # Raw PCM bytes not yet handed to ASR
        self.audio_buffer = bytearray()

        logger.info("✅ Voice agent with validation initialized")

    def close(self) -> None:
        """Release the TTS engine. Safe to call more than once.

        After closing, ``synthesize`` returns ``b''`` (the failure is logged).
        """
        tts = getattr(self, "tts", None)
        if tts is None:
            # Never initialized (e.g. __init__ failed early) or already closed.
            return
        self.tts = None
        try:
            tts.close()
            logger.info("✅ TTS closed")
        except Exception as e:
            logger.error(f"Error closing TTS: {e}")

    def __del__(self):
        """Best-effort cleanup when the agent is garbage-collected."""
        self.close()

    async def transcribe(self, audio_bytes: bytes) -> str:
        """Transcribe 16-bit PCM audio to text.

        Args:
            audio_bytes: Little-endian int16 PCM at ``Config.SAMPLE_RATE``
                (assumed mono; TODO confirm against the client). A trailing
                odd byte, if any, is ignored.

        Returns:
            The stripped transcript, or ``""`` when nothing was recognized or
            transcription failed (errors are logged, never raised).
        """
        try:
            # np.frombuffer rejects buffers that are not a whole number of int16 samples.
            usable = len(audio_bytes) - len(audio_bytes) % 2
            audio_array = np.frombuffer(audio_bytes[:usable], dtype=np.int16)
            if audio_array.size == 0:
                return ""

            # Scale int16 to float32 in [-1.0, 1.0) for the ASR model.
            audio_float = audio_array.astype(np.float32) / 32768.0

            segments, _info = self.asr_model.transcribe(
                audio_float,
                language=Config.WHISPER_LANGUAGE,
                vad_filter=False
            )

            text = " ".join(seg.text for seg in segments).strip()

            # After strip(), a non-empty string always has at least one word.
            if not text:
                logger.debug(f"Too short: '{text}'")
                return ""

            logger.info(f"🎤 Transcribed: {text}")
            return text

        except Exception as e:
            logger.error(f"Transcription error: {e}")
            return ""

    @staticmethod
    def _to_mono_int16(audio_data: bytes, n_channels: int, src_rate: int, dst_rate: int) -> bytes:
        """Convert interleaved 16-bit PCM to mono 16-bit PCM at ``dst_rate``.

        Channels are averaged. When the rates differ, the signal is resampled
        with ``scipy.signal.resample`` (FFT based). The result is clipped to
        the int16 range before casting, so resampling overshoot saturates
        instead of wrapping to the opposite sign.

        Returns ``b""`` for empty input.
        """
        samples = np.frombuffer(audio_data, dtype=np.int16)
        if n_channels > 1:
            # Drop a trailing partial frame (if any) so the reshape is exact.
            whole = samples.size - samples.size % n_channels
            samples = samples[:whole].reshape(-1, n_channels).mean(axis=1)
        if samples.size == 0:
            return b""
        if src_rate != dst_rate:
            num_samples = int(len(samples) * dst_rate / src_rate)
            samples = signal.resample(samples, num_samples)
        return np.clip(samples, -32768, 32767).astype(np.int16).tobytes()

    async def synthesize(self, text: str) -> bytes:
        """Synthesize speech with VieNeu TTS.

        Before synthesis, the text is stripped, truncated to 500 characters
        and whitespace-collapsed.

        Returns:
            Mono int16 PCM at ``Config.SAMPLE_RATE``, or ``b''`` on empty
            input or any TTS/WAV failure (errors are logged, never raised).
        """
        try:
            if not text or not text.strip():
                logger.warning("Empty text for TTS")
                return b''

            text = text.strip()
            if len(text) > 500:
                logger.warning(f"Text too long ({len(text)} chars), truncating")
                text = text[:500]

            text = ' '.join(text.split())

            logger.debug(f"🔊 Synthesizing: {text}")

            audio_spec = self.tts.infer(text=text, voice=self.voice_data)

            # A private temporary directory avoids the race inherent in
            # tempfile.mktemp(). It is removed, together with the WAV inside it,
            # on every exit path, including exceptions.
            with tempfile.TemporaryDirectory() as tmp_dir:
                audio_file = os.path.join(tmp_dir, "tts.wav")
                self.tts.save(audio_spec, audio_file)

                if not os.path.exists(audio_file):
                    logger.error("VieNeu didn't create audio file")
                    return b''

                file_size = os.path.getsize(audio_file)
                if file_size < 100:
                    logger.error(f"Audio file too small: {file_size} bytes")
                    return b''

                try:
                    with wave.open(audio_file, 'rb') as wav:
                        vieneu_sample_rate = wav.getframerate()
                        n_channels = wav.getnchannels()
                        sample_width = wav.getsampwidth()
                        audio_data = wav.readframes(wav.getnframes())
                except wave.Error as e:
                    logger.error(f"WAV read error: {e}")
                    return b''

            # The downstream conversion and the client both expect int16 samples.
            if sample_width != 2:
                logger.error(f"Unsupported WAV sample width: {sample_width} bytes")
                return b''

            audio_data = self._to_mono_int16(
                audio_data, n_channels, vieneu_sample_rate, Config.SAMPLE_RATE
            )

            logger.debug(f"✅ Synthesized {len(audio_data)} bytes")
            return audio_data

        except Exception as e:
            logger.error(f"TTS error: {e}")
            return b''

    async def process_audio(self, audio_chunk: bytes) -> Optional[bytes]:
        """Buffer an incoming PCM chunk and run one turn once enough audio is available.

        A turn is: transcribe, then ``validation_sm.process``, then synthesize.

        Returns:
            ``None`` while the buffer holds less than ``Config.CHUNK_DURATION``
            seconds of audio, or when nothing was transcribed. Otherwise the
            synthesized reply, which is ``b''`` if TTS failed.
        """
        self.audio_buffer.extend(audio_chunk)

        # CHUNK_DURATION seconds of 2-byte samples at SAMPLE_RATE
        chunk_size = Config.SAMPLE_RATE * 2 * Config.CHUNK_DURATION

        if len(self.audio_buffer) < chunk_size:
            return None

        # Only hand whole int16 samples to ASR. A dangling odd byte stays in the
        # buffer so the next chunk remains sample-aligned.
        usable = len(self.audio_buffer) - len(self.audio_buffer) % 2
        audio_to_process = bytes(self.audio_buffer[:usable])
        del self.audio_buffer[:usable]

        # 1. Transcribe
        user_text = await self.transcribe(audio_to_process)

        if not user_text:
            return None

        # 2. Process through validation state machine
        assistant_text = await self.validation_sm.process(user_text)

        # 3. Synthesize
        audio_response = await self.synthesize(assistant_text)

        logger.info(f"📊 Context: {json.dumps(self.validation_sm.context.to_dict(), ensure_ascii=False, indent=2)}")

        return audio_response


# ============================================================================
# WEBSOCKET SERVER
# ============================================================================

async def handle_client(websocket):
    """Handle one WebSocket client connection.

    Creates a dedicated ``VoiceAgentWithValidation`` for the connection and
    sends the synthesized greeting. Every binary frame is then fed to
    ``agent.process_audio``, and any non-empty audio reply is sent back.
    Non-binary frames are ignored. Errors are logged and end the session.
    """
    remote = websocket.remote_address
    # For TCP this is a (host, port, ...) tuple. Guard other shapes (e.g. None)
    # so that building the id cannot crash the handler.
    if isinstance(remote, tuple) and len(remote) >= 2:
        client_id = f"{remote[0]}:{remote[1]}"
    else:
        client_id = str(remote)
    logger.info(f"✅ Client connected: {client_id}")

    # Log the final-context snapshot once, not on every frame after completion.
    completion_logged = False

    try:
        # Create the agent for this client. The constructor loads ASR/TTS models,
        # so it sits inside the try: failures are then logged against this client.
        # NOTE(review): per-connection model loading is expensive and blocks the
        # event loop (synchronous call in a coroutine); consider sharing models.
        agent = VoiceAgentWithValidation()

        # Send greeting
        greeting = await agent.validation_sm.start()
        greeting_audio = await agent.synthesize(greeting)
        if greeting_audio:
            await websocket.send(greeting_audio)

        # Process incoming audio
        async for message in websocket:
            if not isinstance(message, bytes):
                continue

            response_audio = await agent.process_audio(message)
            if response_audio:
                await websocket.send(response_audio)

            if not completion_logged and agent.validation_sm.is_complete():
                completion_logged = True
                logger.info(f"✅ Validation complete for {client_id}")
                logger.info(f"Final context: {json.dumps(agent.validation_sm.context.to_dict(), ensure_ascii=False, indent=2)}")
                # In Phase 2, we would transition to intent detection here

    except websockets.exceptions.ConnectionClosed:
        logger.info(f"Client disconnected: {client_id}")
    except Exception as e:
        # logger.exception records the traceback in the loguru sink instead of raw stderr.
        logger.exception(f"Error handling client {client_id}: {e}")
    finally:
        logger.info(f"Cleaned up connection: {client_id}")


async def start_server():
    """Run the WebSocket server until the awaiting task is cancelled.

    Serves ``handle_client`` on ``Config.WS_HOST:Config.WS_PORT`` and accepts
    messages of up to 10,000,000 bytes.
    """
    host, port = Config.WS_HOST, Config.WS_PORT
    logger.info(f"Starting WebSocket server on {host}:{port}")

    server_ctx = websockets.serve(handle_client, host, port, max_size=10_000_000)
    async with server_ctx:
        logger.info(f"✅ Voice agent listening on ws://{host}:{port}")
        logger.info("Press Ctrl+C to stop")
        # A future nobody resolves: park here while the server handles connections.
        await asyncio.get_running_loop().create_future()


# ============================================================================
# MAIN
# ============================================================================

async def main():
    """Log the startup banner, then run the WebSocket server until stopped.

    Logs a shutdown message when the server is stopped by Ctrl+C.
    """
    rule = "=" * 70
    logger.info(rule)
    logger.info("🎙️  ELECTRICITY CALL CENTER - PHASE 1: VALIDATION")
    logger.info(rule)
    logger.info("Language:     Vietnamese")
    logger.info(f"LLM Server:   {Config.VLLM_BASE_URL}")
    logger.info(f"TTS Model:    VieNeu ({Config.VIENEU_MODEL_DIR})")
    logger.info(f"Customer DB:  {Config.CUSTOMER_DB_FILE}")
    logger.info(f"WebSocket:    ws://{Config.WS_HOST}:{Config.WS_PORT}")
    logger.info(rule)
    logger.info("")

    try:
        await start_server()
    except KeyboardInterrupt:
        logger.info("")
        logger.info("Shutting down...")
    except asyncio.CancelledError:
        # Under asyncio.run(), Ctrl+C reaches this coroutine as a cancellation
        # (3.11+ cancels the main task on SIGINT; older versions cancel it during
        # loop teardown), not as KeyboardInterrupt. Log it, then re-raise so
        # cancellation semantics are preserved.
        logger.info("")
        logger.info("Shutting down...")
        raise


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C surfaces here as KeyboardInterrupt once asyncio.run() has torn
        # down the event loop; exit quietly instead of dumping a traceback.
        logger.info("Stopped.")
