#!/usr/bin/env python3
"""
Phase 1 + Phase 2: User Validation + Intent Detection
Electricity Company Call Center - Vietnam

Phase 1: User Validation (COMPLETE)
Phase 2: Intent Detection (NEW)
  - After validation complete, detect user intent
  - 4 intents: BILLING_INQUIRY, NEW_REGISTRATION, INCIDENT_REPORT, OUTAGE_INFO
  - Confirm intent with user before proceeding

Phase 3: Intent handlers (TO BE IMPLEMENTED)

Audio segmentation: Silero VAD (replaces fixed CHUNK_DURATION)
"""

import asyncio
import os
import sys
import json
import re
import time
from pathlib import Path
from typing import Optional, Dict, Any, Tuple
from dataclasses import dataclass
from enum import Enum
import ssl
# Core imports
import websockets
from loguru import logger
from dotenv import load_dotenv

# Local model imports
from openai import AsyncOpenAI
import tempfile
import wave
import numpy as np
from scipy import signal
from faster_asr import MyAsrModel
from faster_asr.vad import get_speech_timestamps, VadOptions
from vieneu import Vieneu

# Load environment
load_dotenv()


# ============================================================================
# CONFIGURATION
# ============================================================================

class Config:
    """Configuration for validation system"""

    # ASR
    WHISPER_LANGUAGE = os.getenv("WHISPER_LANGUAGE", "vi")

    # vLLM
    VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1")
    VLLM_MODEL = os.getenv("VLLM_MODEL", "./models/Qwen/Qwen2.5-7B-Instruct")

    # VieNeu TTS
    VIENEU_MODEL_DIR = os.getenv("VIENEU_MODEL_DIR", "vieneu-0.3B")
    VIENEU_VOICE_ID = os.getenv("VIENEU_VOICE_ID", "Ly")

    # WebSocket
    WS_HOST = os.getenv("WS_HOST", "0.0.0.0")
    WS_PORT = int(os.getenv("WS_PORT", "8765"))

    # Audio
    SAMPLE_RATE = 16000

    # VAD settings (replaces fixed CHUNK_DURATION)
    VAD_THRESHOLD = 0.5
    VAD_MIN_SPEECH_MS = 250
    VAD_MIN_SILENCE_MS = 700
    VAD_SPEECH_PAD_MS = 30
    VAD_MAX_BUFFER_SEC = 30
    VAD_MIN_BUFFER_SEC = 0.5
    VAD_NO_SPEECH_TRIM_SEC = 5

    # LLM settings
    MAX_TOKENS = int(os.getenv("MAX_TOKENS", "150"))
    TEMPERATURE = float(os.getenv("TEMPERATURE", "0.3"))

    # Validation settings
    MAX_RETRY_PHONE = 3
    MAX_RETRY_NAME = 3
    PHONE_CONFIDENCE_THRESHOLD = 0.7
    NAME_CONFIDENCE_THRESHOLD = 0.6

    # Intent detection settings (NEW)
    INTENT_CONFIDENCE_THRESHOLD = 0.7
    MAX_RETRY_INTENT = 2

    # Customer database
    CUSTOMER_DB_FILE = os.getenv("CUSTOMER_DB_FILE", "./data/customers.csv")

    # Outage database
    OUTAGE_DB_FILE = os.getenv("OUTAGE_DB_FILE", "./data/outages.csv")

    # Incident database
    INCIDENT_DB_FILE = os.getenv("INCIDENT_DB_FILE", "./data/incidents.csv")


# ============================================================================
# DATA MODELS
# ============================================================================

class ValidationState(Enum):
    """State machine states for validation"""
    GREETING = "GREETING"
    AWAIT_PHONE_REQUEST = "AWAIT_PHONE_REQUEST"
    AWAIT_NAME_REQUEST = "AWAIT_NAME_REQUEST"
    VALIDATING_USER = "VALIDATING_USER"
    VALIDATION_COMPLETE = "VALIDATION_COMPLETE"
    ESCALATE_TO_HUMAN = "ESCALATE_TO_HUMAN"


class IntentType(Enum):
    """Intent types - Phase 2"""
    BILLING_INQUIRY = "BILLING_INQUIRY"
    NEW_REGISTRATION = "NEW_REGISTRATION"
    INCIDENT_REPORT = "INCIDENT_REPORT"
    OUTAGE_INFO = "OUTAGE_INFO"
    UNCLEAR = "UNCLEAR"


class IntentDetectionState(Enum):
    """State machine states for intent detection - Phase 2"""
    AWAITING_INTENT = "AWAITING_INTENT"
    INTENT_CONFIRMED = "INTENT_CONFIRMED"


@dataclass
class ExtractionResult:
    """Result from LLM extraction"""
    success: bool
    value: Optional[str]
    confidence: float
    needs_confirmation: bool
    failure_reason: Optional[str] = None
    user_provided_voluntarily: bool = False


@dataclass
class IntentResult:
    """Result from intent detection - Phase 2"""
    intent: IntentType
    confidence: float
    keywords_found: list[str]
    needs_clarification: bool
    user_description: Optional[str] = None  # What user said in their own words


@dataclass
class ValidationContext:
    """Context maintained throughout validation"""
    phone_number: Optional[str] = None
    phone_confirmed: bool = False
    customer_name: Optional[str] = None
    name_confirmed: bool = False
    database_name: Optional[str] = None
    account_id: Optional[str] = None
    customer_status: Optional[str] = None  # existing_verified | existing_mismatch | new_customer
    retry_count_phone: int = 0
    retry_count_name: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for logging"""
        return {
            "phone": self.phone_number,
            "phone_confirmed": self.phone_confirmed,
            "name": self.customer_name,
            "name_confirmed": self.name_confirmed,
            "status": self.customer_status,
            "retries": {
                "phone": self.retry_count_phone,
                "name": self.retry_count_name
            }
        }


@dataclass
class IntentContext:
    """Context for intent detection - Phase 2"""
    detected_intent: Optional[IntentType] = None
    intent_confidence: float = 0.0
    intent_confirmed: bool = False
    user_request_text: Optional[str] = None
    retry_count_intent: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for logging"""
        return {
            "intent": self.detected_intent.value if self.detected_intent else None,
            "confidence": self.intent_confidence,
            "confirmed": self.intent_confirmed,
            "request": self.user_request_text,
            "retries": self.retry_count_intent
        }


# ============================================================================
# CUSTOMER DATABASE (from Phase 1)
# ============================================================================

class CustomerDatabase:
    """Simple customer database (CSV-based)"""

    def __init__(self, db_file: str):
        self.db_file = db_file
        self.customers = {}
        self._load_database()

    def _load_database(self):
        """Load customer data from CSV"""
        if not os.path.exists(self.db_file):
            logger.warning(f"Customer database not found: {self.db_file}")
            logger.info("Creating sample database...")
            self._create_sample_db()

        try:
            with open(self.db_file, 'r', encoding='utf-8') as f:
                lines = f.readlines()
                for i, line in enumerate(lines):
                    if i == 0:  # Skip header
                        continue
                    parts = line.strip().split(',')
                    if len(parts) >= 3:
                        phone, name, account_id = parts[0], parts[1], parts[2]
                        self.customers[phone] = {
                            'name': name,
                            'account_id': account_id,
                            'phone': phone,
                            'electricity_charge_vnd': parts[-1] if len(parts) > 3 else None,
                        }
            logger.info(f"Loaded {len(self.customers)} customers from database")
        except Exception as e:
            logger.error(f"Error loading database: {e}")

    def _create_sample_db(self):
        """Create sample database for testing"""
        os.makedirs(os.path.dirname(self.db_file), exist_ok=True)
        sample_data = """phone_number,customer_name,account_id,registration_date
0901234567,Nguyễn Văn An,KH001,2023-01-15
0912345678,Trần Thị Bình,KH002,2023-02-20
0923456789,Lê Văn Công,KH003,2023-03-10
0934567890,Phạm Thị Dung,KH004,2023-04-05
0945678901,Hoàng Văn Em,KH005,2023-05-12
"""
        with open(self.db_file, 'w', encoding='utf-8') as f:
            f.write(sample_data)
        logger.info(f"Created sample database: {self.db_file}")

    def lookup(self, phone: str) -> Optional[Dict[str, str]]:
        """Lookup customer by phone number"""
        return self.customers.get(phone)

    def register(self, phone: str, name: str) -> str:
        """Register a new customer. Returns the new account_id."""
        from datetime import date

        # Generate next account ID
        existing_ids = [
            int(c['account_id'].replace('KH', ''))
            for c in self.customers.values()
            if c['account_id'].startswith('KH')
        ]
        next_id = max(existing_ids, default=0) + 1
        account_id = f"KH{next_id:03d}"
        today = date.today().isoformat()

        # Add to in-memory dict
        self.customers[phone] = {
            'name': name,
            'account_id': account_id,
            'phone': phone,
            'electricity_charge_vnd': None,
        }

        # Append to CSV file
        with open(self.db_file, 'a', encoding='utf-8') as f:
            f.write(f"{phone},{name},{account_id},{today},\n")

        logger.info(f"Registered new customer: {phone} / {name} -> {account_id}")
        return account_id

    def fuzzy_match_name(self, input_name: str, db_name: str) -> float:
        """Fuzzy match Vietnamese names - Returns similarity score 0.0-1.0"""
        def normalize(text: str) -> str:
            text = text.lower()
            text = text.replace('đ', 'd')
            replacements = {
                'á': 'a', 'à': 'a', 'ả': 'a', 'ã': 'a', 'ạ': 'a',
                'ă': 'a', 'ắ': 'a', 'ằ': 'a', 'ẳ': 'a', 'ẵ': 'a', 'ặ': 'a',
                'â': 'a', 'ấ': 'a', 'ầ': 'a', 'ẩ': 'a', 'ẫ': 'a', 'ậ': 'a',
                'é': 'e', 'è': 'e', 'ẻ': 'e', 'ẽ': 'e', 'ẹ': 'e',
                'ê': 'e', 'ế': 'e', 'ề': 'e', 'ể': 'e', 'ễ': 'e', 'ệ': 'e',
                'í': 'i', 'ì': 'i', 'ỉ': 'i', 'ĩ': 'i', 'ị': 'i',
                'ó': 'o', 'ò': 'o', 'ỏ': 'o', 'õ': 'o', 'ọ': 'o',
                'ô': 'o', 'ố': 'o', 'ồ': 'o', 'ổ': 'o', 'ỗ': 'o', 'ộ': 'o',
                'ơ': 'o', 'ớ': 'o', 'ờ': 'o', 'ở': 'o', 'ỡ': 'o', 'ợ': 'o',
                'ú': 'u', 'ù': 'u', 'ủ': 'u', 'ũ': 'u', 'ụ': 'u',
                'ư': 'u', 'ứ': 'u', 'ừ': 'u', 'ử': 'u', 'ữ': 'u', 'ự': 'u',
                'ý': 'y', 'ỳ': 'y', 'ỷ': 'y', 'ỹ': 'y', 'ỵ': 'y',
            }
            for viet, ascii_char in replacements.items():
                text = text.replace(viet, ascii_char)
            return ''.join(text.split())

        norm_input = normalize(input_name)
        norm_db = normalize(db_name)

        if norm_input == norm_db:
            return 1.0
        elif norm_input in norm_db or norm_db in norm_input:
            return 0.85
        else:
            longer = max(len(norm_input), len(norm_db))
            if longer == 0:
                return 0.0
            matches = sum(1 for a, b in zip(norm_input, norm_db) if a == b)
            return matches / longer


# ============================================================================
# OUTAGE DATABASE
# ============================================================================

class OutageDatabase:
    """Simple outage information database (CSV-based)"""

    def __init__(self, db_file: str):
        self.db_file = db_file
        self.outages: list[dict] = []
        self._load_database()

    def _load_database(self):
        """Load outage data from CSV"""
        if not os.path.exists(self.db_file):
            logger.warning(f"Outage database not found: {self.db_file}")
            return

        try:
            with open(self.db_file, 'r', encoding='utf-8') as f:
                lines = f.readlines()
                if not lines:
                    return
                header = lines[0].strip().split(',')
                for line in lines[1:]:
                    parts = line.strip().split(',')
                    if len(parts) >= len(header):
                        row = {header[i]: parts[i] for i in range(len(header))}
                        self.outages.append(row)
            logger.info(f"Loaded {len(self.outages)} outage records")
        except Exception as e:
            logger.error(f"Error loading outage database: {e}")

    def get_all_outages(self) -> list[dict]:
        """Return all current outage records"""
        return self.outages


# ============================================================================
# INCIDENT DATABASE
# ============================================================================

class IncidentDatabase:
    """Append-only incident report database (CSV-based)"""

    def __init__(self, db_file: str):
        self.db_file = db_file
        # Create file with header if it doesn't exist
        if not os.path.exists(db_file):
            os.makedirs(os.path.dirname(db_file), exist_ok=True)
            with open(db_file, 'w', encoding='utf-8') as f:
                f.write("phone,customer_name,locality,description,reported_at\n")

    def save(self, phone: str, name: str, locality: str, description: str):
        """Append an incident report to the CSV."""
        from datetime import datetime
        reported_at = datetime.now().strftime("%Y-%m-%d %H:%M")
        # Escape commas in free-text fields
        locality = locality.replace(',', ' ')
        description = description.replace(',', ' ')
        with open(self.db_file, 'a', encoding='utf-8') as f:
            f.write(f"{phone},{name},{locality},{description},{reported_at}\n")
        logger.info(f"Incident saved: {phone} / {locality} / {description}")


# ============================================================================
# LLM FUNCTIONS (Phase 1 + Phase 2)
# ============================================================================

class LLMFunctions:
    """LLM function calling for data extraction and response generation"""

    def __init__(self, client: AsyncOpenAI):
        self.client = client

    async def _call_llm(self, messages: list, temperature: float = 0.3) -> str:
        """Helper to call LLM"""
        try:
            response = await self.client.chat.completions.create(
                model=Config.VLLM_MODEL,
                messages=messages,
                temperature=temperature,
                max_tokens=Config.MAX_TOKENS
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            logger.error(f"LLM call error: {e}")
            return ""

    # ========================================================================
    # PHONE EXTRACTION (Phase 1 - with improvements)
    # ========================================================================

    @staticmethod
    def vietnamese_words_to_digits(text: str) -> Optional[str]:
        """Convert Vietnamese spoken numbers to digits"""
        VIET_NUMBERS = {
            'không': '0', 'một': '1', 'mốt': '1', 'hai': '2',
            'ba': '3', 'bốn': '4', 'tư': '4', 'năm': '5',
            'lăm': '5', 'sáu': '6', 'bảy': '7', 'tám': '8', 'chín': '9',
        }

        text = text.lower().strip()
        words = text.split()

        digits = []
        for word in words:
            word_clean = word.strip('.,!?;:')
            if word_clean in VIET_NUMBERS:
                digits.append(VIET_NUMBERS[word_clean])

        return ''.join(digits) if digits else None

    @staticmethod
    def extract_phone_from_text_python(text: str) -> tuple:
        """Pure Python extraction (no LLM)"""
        text = text.lower().strip()

        # Method 1: Direct digits
        digit_pattern = r'0\d{9}'
        digit_match = re.search(digit_pattern, text)
        if digit_match:
            return (digit_match.group(), 0.95, "direct_digits")

        # Method 2: Spaced digits
        spaced_pattern = r'0[\s\-]*\d[\s\-]*\d[\s\-]*\d[\s\-]*\d[\s\-]*\d[\s\-]*\d[\s\-]*\d[\s\-]*\d[\s\-]*\d'
        spaced_match = re.search(spaced_pattern, text)
        if spaced_match:
            phone = re.sub(r'[\s\-]', '', spaced_match.group())
            if len(phone) == 10:
                return (phone, 0.90, "spaced_digits")

        # Method 3: Vietnamese words
        converted = LLMFunctions.vietnamese_words_to_digits(text)
        if converted and len(converted) >= 10:
            phone = converted[:10]
            if phone.startswith('0'):
                confidence = 0.85 if len(converted) == 10 else 0.75
                return (phone, confidence, "vietnamese_words")

        # Method 4: Concatenate all digits
        all_digits = re.findall(r'\d+', text)
        if all_digits:
            combined = ''.join(all_digits)
            if len(combined) >= 10 and combined[0] == '0':
                phone = combined[:10]
                return (phone, 0.70, "concatenated_digits")

        return (None, 0.0, "no_match")

    async def extract_phone_number(self, user_speech: str, context: ValidationContext) -> ExtractionResult:
        """Hybrid phone extraction: Python first, LLM fallback"""
        logger.debug(f"Extracting phone from: '{user_speech}'")

        # STEP 1: Try Python extraction first
        phone, confidence, method = self.extract_phone_from_text_python(user_speech)

        if phone and len(phone) == 10:
            logger.info(f"Python extracted: {phone} (method: {method}, conf: {confidence:.2f})")
            return ExtractionResult(
                success=True,
                value=phone,
                confidence=confidence,
                needs_confirmation=confidence < 0.85,
                failure_reason=None
            )

        # STEP 2: Python failed, try LLM
        logger.info(f"Python failed, using LLM...")

        system_prompt = """Bạn là trợ lý trích xuất số điện thoại.

NHIỆM VỤ: Trích xuất số điện thoại 10 chữ số bắt đầu bằng 0.

CHUYỂN ĐỔI CHỮ SANG SỐ:
không=0, một=1, mốt=1, hai=2, ba=3, bốn=4, tư=4, năm=5, lăm=5, sáu=6, bảy=7, tám=8, chín=9

OUTPUT (JSON):
{"phone_number": "0901234567", "confidence": 0.9, "needs_confirmation": false, "failure_reason": null}

VÍ DỤ:
"không chín không một hai ba bốn năm sáu bảy" → {"phone_number": "0901234567", "confidence": 0.9, "needs_confirmation": false, "failure_reason": null}
"Tôi không nhớ" → {"phone_number": null, "confidence": 0.0, "needs_confirmation": false, "failure_reason": "không nhớ"}
"""

        user_prompt = f"""Khách: "{user_speech}"

Trích xuất số (10 chữ số, bắt đầu 0):"""

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        response = await self._call_llm(messages, temperature=0.1)
        logger.debug(f"LLM response: {response}")

        try:
            if "```json" in response:
                response = response.split("```json")[1].split("```")[0].strip()
            elif "```" in response:
                response = response.split("```")[1].split("```")[0].strip()

            data = json.loads(response)
            phone = data.get("phone_number")
            llm_confidence = data.get("confidence", 0.0)

            if phone:
                phone = phone.strip().replace(" ", "").replace("-", "")

                if not phone.isdigit():
                    digits_only = ''.join(c for c in phone if c.isdigit())
                    if len(digits_only) >= 10 and digits_only[0] == '0':
                        phone = digits_only[:10]
                    else:
                        return ExtractionResult(False, None, 0.0, False, "Không hợp lệ")

                if not (phone.startswith("0") and len(phone) == 10):
                    return ExtractionResult(False, None, 0.0, False,
                        f"Định dạng sai (phải 10 số, bắt đầu 0)")

                logger.info(f"LLM extracted: {phone} (conf: {llm_confidence:.2f})")

            return ExtractionResult(
                success=phone is not None,
                value=phone,
                confidence=llm_confidence,
                needs_confirmation=data.get("needs_confirmation", llm_confidence < 0.7),
                failure_reason=data.get("failure_reason")
            )

        except Exception as e:
            logger.error(f"LLM parse error: {e}")
            return ExtractionResult(False, None, 0.0, False, "Lỗi xử lý")

    # ========================================================================
    # NAME EXTRACTION (Phase 1)
    # ========================================================================

    async def extract_name(self, user_speech: str, context: ValidationContext) -> ExtractionResult:
        """Extract customer name from Vietnamese speech"""
        system_prompt = """Bạn là trợ lý trích xuất tên khách hàng.

NHIỆM VỤ: Trích xuất họ và tên của khách hàng.

QUY TẮC:
1. Tên Việt Nam thường có 2-4 từ (VD: Nguyễn Văn An, Trần Thị Bình)
2. Xử lý các từ lắp và lời xưng hô (tôi là, tên tôi là, con/cháu/em tên là)
3. Giữ nguyên dấu tiếng Việt

OUTPUT FORMAT (JSON):
{
  "name": "Nguyễn Văn An" hoặc null,
  "confidence": 0.0-1.0,
  "needs_confirmation": true/false,
  "failure_reason": null
}
"""

        user_prompt = f"""Khách hàng nói: "{user_speech}"

Trích xuất tên:"""

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        response = await self._call_llm(messages, temperature=0.1)

        try:
            if "```json" in response:
                response = response.split("```json")[1].split("```")[0].strip()
            elif "```" in response:
                response = response.split("```")[1].split("```")[0].strip()

            data = json.loads(response)
            name = data.get("name")
            confidence = data.get("confidence", 0.0)

            if name:
                name = name.strip()
                if len(name.split()) < 2:
                    return ExtractionResult(False, None, 0.0, False, "Tên không đầy đủ")

            return ExtractionResult(
                success=name is not None,
                value=name,
                confidence=confidence,
                needs_confirmation=data.get("needs_confirmation", confidence < 0.6),
                failure_reason=data.get("failure_reason")
            )

        except Exception as e:
            logger.error(f"Error parsing name: {e}")
            return ExtractionResult(False, None, 0.0, False, "Lỗi xử lý")

    # ========================================================================
    # INTENT DETECTION (Phase 2 - NEW)
    # ========================================================================

    async def detect_intent(self, user_speech: str, validation_context: ValidationContext) -> IntentResult:
        """
        Detect user intent from their request
        4 intents: BILLING_INQUIRY, NEW_REGISTRATION, INCIDENT_REPORT, OUTAGE_INFO
        """

        system_prompt = """Bạn là trợ lý phân loại yêu cầu của khách hàng công ty điện lực.

NHIỆM VỤ: Xác định ý định của khách hàng từ câu nói.

4 LOẠI YÊU CẦU:

1. BILLING_INQUIRY (Tra cứu hóa đơn)
   - Hỏi về số tiền hóa đơn, tiền điện tháng này/trước
   - Tại sao hóa đơn cao/thấp
   - Hạn thanh toán
   Từ khóa: hóa đơn, tiền điện, chi phí, bao nhiêu, tháng này, tháng trước

2. NEW_REGISTRATION (Đăng ký mới)
   - Đăng ký điện mới
   - Lắp điện
   - Tăng công suất
   - Thay đổi hợp đồng
   Từ khóa: đăng ký mới, lắp điện, tăng công suất, đổi hợp đồng

3. INCIDENT_REPORT (Báo sự cố)
   - Báo mất điện
   - Điện yếu, chập chờn
   - Sự cố điện
   Từ khóa: mất điện, cúp điện, chập chờn, điện yếu, sự cố

4. OUTAGE_INFO (Kiểm tra thông tin mất điện)
   - Hỏi khi nào có điện trở lại
   - Kiểm tra khu vực có mất điện không
   - Lịch bảo trì
   Từ khóa: khi nào có điện, bao giờ, khu vực mất điện, lịch bảo trì

KHÁCH HÀNG: {customer_status}
Tài khoản: {account_id}

OUTPUT (JSON):
{{
  "intent": "BILLING_INQUIRY" hoặc "NEW_REGISTRATION" hoặc "INCIDENT_REPORT" hoặc "OUTAGE_INFO" hoặc "UNCLEAR",
  "confidence": 0.0-1.0,
  "keywords_found": ["từ khóa 1", "từ khóa 2"],
  "needs_clarification": true/false,
  "user_description": "tóm tắt yêu cầu bằng tiếng Việt"
}}

VÍ DỤ:

Input: "Hóa đơn tháng này của tôi bao nhiêu tiền?"
Output: {{"intent": "BILLING_INQUIRY", "confidence": 0.95, "keywords_found": ["hóa đơn", "tháng này", "bao nhiêu"], "needs_clarification": false, "user_description": "tra cứu hóa đơn tháng này"}}

Input: "Nhà tôi mất điện rồi"
Output: {{"intent": "INCIDENT_REPORT", "confidence": 0.90, "keywords_found": ["mất điện"], "needs_clarification": false, "user_description": "báo mất điện"}}

Input: "Tôi muốn đăng ký điện mới"
Output: {{"intent": "NEW_REGISTRATION", "confidence": 0.95, "keywords_found": ["đăng ký", "mới"], "needs_clarification": false, "user_description": "đăng ký dịch vụ điện mới"}}

Input: "Khi nào có điện trở lại?"
Output: {{"intent": "OUTAGE_INFO", "confidence": 0.90, "keywords_found": ["khi nào", "có điện"], "needs_clarification": false, "user_description": "hỏi thời gian có điện trở lại"}}

Input: "Tôi cần hỗ trợ"
Output: {{"intent": "UNCLEAR", "confidence": 0.3, "keywords_found": [], "needs_clarification": true, "user_description": "yêu cầu chung, cần làm rõ"}}
"""

        # Fill in customer context
        customer_status = validation_context.customer_status or "unknown"
        account_id = validation_context.account_id or "N/A"

        system_prompt = system_prompt.format(
            customer_status=customer_status,
            account_id=account_id
        )

        user_prompt = f"""Khách hàng nói: "{user_speech}"

Phân loại yêu cầu và trả về JSON:"""

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        response = await self._call_llm(messages, temperature=0.2)
        logger.debug(f"Intent detection response: {response}")

        try:
            # Parse JSON
            if "```json" in response:
                response = response.split("```json")[1].split("```")[0].strip()
            elif "```" in response:
                response = response.split("```")[1].split("```")[0].strip()

            data = json.loads(response)

            intent_str = data.get("intent", "UNCLEAR")

            # Map string to IntentType enum
            try:
                intent = IntentType[intent_str]
            except KeyError:
                logger.warning(f"Unknown intent: {intent_str}, defaulting to UNCLEAR")
                intent = IntentType.UNCLEAR

            confidence = data.get("confidence", 0.0)
            keywords = data.get("keywords_found", [])
            needs_clarification = data.get("needs_clarification", confidence < Config.INTENT_CONFIDENCE_THRESHOLD)
            user_description = data.get("user_description", user_speech[:50])

            logger.info(f"Detected intent: {intent.value} (conf: {confidence:.2f})")

            return IntentResult(
                intent=intent,
                confidence=confidence,
                keywords_found=keywords,
                needs_clarification=needs_clarification,
                user_description=user_description
            )

        except Exception as e:
            logger.error(f"Error parsing intent: {e}, Response: {response}")
            return IntentResult(
                intent=IntentType.UNCLEAR,
                confidence=0.0,
                keywords_found=[],
                needs_clarification=True,
                user_description=user_speech[:50]
            )

    # ========================================================================
    # INCIDENT INFO EXTRACTION (Phase 3)
    # ========================================================================

    async def extract_incident_info(self, user_speech: str) -> Dict[str, Optional[str]]:
        """Extract locality and description from user's incident report speech."""
        system_prompt = """Bạn là trợ lý trích xuất thông tin sự cố điện.

NHIỆM VỤ: Trích xuất ĐỊA ĐIỂM và MÔ TẢ SỰ CỐ từ câu nói của khách hàng.

QUY TẮC:
1. locality: Khu vực/địa chỉ xảy ra sự cố (quận, phường, đường, số nhà...)
2. description: Mô tả sự cố (mất điện, điện yếu, chập chờn, cháy nổ...)
3. Nếu không tìm thấy thông tin nào, trả về null

OUTPUT (JSON):
{
  "locality": "Quận 7" hoặc null,
  "description": "Mất điện toàn bộ khu vực" hoặc null
}

VÍ DỤ:
"Nhà tôi ở quận 7 bị mất điện từ sáng" → {"locality": "Quận 7", "description": "Mất điện từ sáng"}
"Khu vực phường Tân Phú điện chập chờn" → {"locality": "Phường Tân Phú", "description": "Điện chập chờn"}
"Mất điện rồi" → {"locality": null, "description": "Mất điện"}
"""

        user_prompt = f"""Khách hàng nói: "{user_speech}"

Trích xuất địa điểm và mô tả sự cố:"""

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        response = await self._call_llm(messages, temperature=0.1)
        logger.debug(f"Incident extraction response: {response}")

        try:
            if "```json" in response:
                response = response.split("```json")[1].split("```")[0].strip()
            elif "```" in response:
                response = response.split("```")[1].split("```")[0].strip()

            data = json.loads(response)
            locality = data.get("locality")
            description = data.get("description")

            logger.info(f"Incident extracted: locality={locality}, description={description}")
            return {"locality": locality, "description": description}

        except Exception as e:
            logger.error(f"Error parsing incident info: {e}")
            return {"locality": None, "description": None}

    # ========================================================================
    # RESPONSE GENERATION (Phase 1 + Phase 2)
    # ========================================================================

    async def generate_response(self, state, context, extra_info: Optional[Dict[str, Any]] = None) -> str:
        """
        Generate natural Vietnamese response based on state and context
        Handles both ValidationState and IntentDetectionState
        """

        # Check if this is validation state or intent detection state
        if isinstance(state, ValidationState):
            return await self._generate_validation_response(state, context, extra_info)
        elif isinstance(state, IntentDetectionState):
            return await self._generate_intent_response(state, context, extra_info)
        else:
            return "Xin lỗi, có lỗi xảy ra."

    async def _generate_validation_response(self, state: ValidationState,
                                           context: ValidationContext,
                                           extra_info: Optional[Dict[str, Any]]) -> str:
        """Generate response for validation states"""

        state_templates = {
            ValidationState.GREETING: "Chào bạn tới với tổng đài điện lực.Để phục vụ quý khách tốt hơn, xin cho biết số điện thoại của quý khách ạ?",
            ValidationState.AWAIT_PHONE_REQUEST: "Để phục vụ quý khách tốt hơn, xin cho biết số điện thoại của quý khách ạ?",
            ValidationState.AWAIT_NAME_REQUEST: "Xin cho biết họ và tên của quý khách ạ?",
            ValidationState.VALIDATION_COMPLETE: self._get_validation_complete_template(context)
        }

        # Handle retry scenarios
        if extra_info and extra_info.get("retry"):
            retry_count = extra_info.get("retry_count", 0)
            if state == ValidationState.AWAIT_PHONE_REQUEST:
                if retry_count == 1:
                    return "Xin lỗi, tôi chưa nghe rõ số điện thoại. Quý khách vui lòng nói lại số điện thoại ạ?"
                elif retry_count >= 2:
                    return "Tôi vẫn chưa nghe rõ số. Quý khách có thể nói từng số một được không ạ?"

            elif state == ValidationState.AWAIT_NAME_REQUEST:
                if retry_count == 1:
                    return "Xin lỗi, tôi chưa nghe rõ tên. Quý khách vui lòng nói lại họ và tên ạ?"
                elif retry_count >= 2:
                    return "Quý khách vui lòng nói chậm và rõ họ tên được không ạ?"

        return state_templates.get(state, "")

    async def _generate_intent_response(self, state: IntentDetectionState,
                                       context: IntentContext,
                                       extra_info: Optional[Dict[str, Any]]) -> str:
        """Generate response for intent detection states - Phase 2"""

        if state == IntentDetectionState.AWAITING_INTENT:
            # This shouldn't be called - agent prompts directly
            return ""

        elif state == IntentDetectionState.INTENT_CONFIRMED:
            return "Được ạ, để tôi hỗ trợ quý khách."

        else:
            # Need clarification
            if extra_info and extra_info.get("retry"):
                return "Xin lỗi, tôi chưa hiểu rõ yêu cầu. Quý khách có thể nói rõ hơn được không? Quý khách muốn tra cứu hóa đơn, đăng ký mới, báo sự cố, hay kiểm tra mất điện ạ?"
            else:
                return "Quý khách muốn tra cứu hóa đơn, đăng ký dịch vụ mới, báo cáo sự cố, hay kiểm tra thông tin mất điện ạ?"

    def _get_validation_complete_template(self, context: ValidationContext) -> str:
        """Get response template for validation complete state"""
        if context.customer_status == "existing_verified":
            return f"Xin chào anh/chị {context.customer_name}. Hệ thống đã xác nhận thông tin. Quý khách cần hỗ trợ gì ạ?"

        elif context.customer_status == "existing_mismatch":
            return f"Số điện thoại này đã đăng ký với tên {context.database_name}. Quý khách có phải là người đăng ký không ạ?"

        elif context.customer_status == "new_customer":
            return "Hệ thống chưa có thông tin của quý khách. Quý khách cần hỗ trợ gì ạ?"

        return "Đã xác nhận thông tin. Quý khách cần hỗ trợ gì ạ?"


# ============================================================================
# VALIDATION STATE MACHINE (Phase 1)
# ============================================================================

class ValidationStateMachine:
    """State machine for Phase 1: User Validation"""

    def __init__(self, llm_client: AsyncOpenAI, customer_db: CustomerDatabase):
        self.state = ValidationState.GREETING
        self.context = ValidationContext()
        self.llm_functions = LLMFunctions(llm_client)
        self.customer_db = customer_db
        self.pending_confirmation = None

        logger.info("Validation state machine initialized")

    def is_complete(self) -> bool:
        """Check if validation is complete"""
        return self.state == ValidationState.VALIDATION_COMPLETE

    def get_context(self) -> ValidationContext:
        """Get current validation context"""
        return self.context

    async def start(self) -> str:
        """Start the validation flow"""
        self.state = ValidationState.GREETING
        response = await self.llm_functions.generate_response(self.state, self.context, None)
        self.state = ValidationState.AWAIT_PHONE_REQUEST
        return response

    async def process(self, user_speech: str) -> str:
        """Process user input based on current state"""
        logger.info(f"State: {self.state.value} | User: {user_speech}")

        if self.state == ValidationState.AWAIT_PHONE_REQUEST:
            return await self._handle_await_phone(user_speech)

        elif self.state == ValidationState.AWAIT_NAME_REQUEST:
            return await self._handle_await_name(user_speech)

        elif self.state == ValidationState.VALIDATION_COMPLETE:
            # This shouldn't be called - control passed to intent detection
            return ""
        else:
            logger.error(f"Unknown state: {self.state}")
            return "Xin lỗi, có lỗi xảy ra."

    async def _handle_await_phone(self, user_speech: str) -> str:
        """Handle AWAIT_PHONE_REQUEST state"""
        result = await self.llm_functions.extract_phone_number(user_speech, self.context)

        if result.success and result.value:
            self.context.phone_number = result.value
            logger.info(f"Extracted phone: {result.value} (confidence: {result.confidence})")

            if result.confidence > 0.9 and not result.needs_confirmation:
                self.context.phone_confirmed = True
                self.state = ValidationState.AWAIT_NAME_REQUEST
                return await self.llm_functions.generate_response(self.state, self.context, None)
            else:
                self.state = ValidationState.AWAIT_NAME_REQUEST
                self.pending_confirmation = "phone"
                return await self.llm_functions.generate_response(self.state, self.context, None)
        else:
            self.context.retry_count_phone += 1
            logger.warning(f"Phone extraction failed (attempt {self.context.retry_count_phone})")

            if self.context.retry_count_phone >= Config.MAX_RETRY_PHONE:
                self.state = ValidationState.ESCALATE_TO_HUMAN
                return "Xin lỗi quý khách. Để được hỗ trợ tốt hơn, tôi sẽ chuyển máy cho nhân viên."

            return await self.llm_functions.generate_response(
                self.state,
                self.context,
                {"retry": True, "retry_count": self.context.retry_count_phone}
            )

    async def _handle_await_name(self, user_speech: str) -> str:
        """Handle AWAIT_NAME_REQUEST state"""
        result = await self.llm_functions.extract_name(user_speech, self.context)

        if result.success and result.value:
            self.context.customer_name = result.value
            logger.info(f"Extracted name: {result.value} (confidence: {result.confidence})")

            if result.confidence > 0.85 and not result.needs_confirmation:
                self.context.name_confirmed = True
                self.state = ValidationState.VALIDATING_USER
                return await self._validate_user()
            else:
                self.state = ValidationState.VALIDATING_USER
                self.pending_confirmation = "name"
                return await self.llm_functions.generate_response(self.state, self.context, None)
        else:
            self.context.retry_count_name += 1
            logger.warning(f"Name extraction failed (attempt {self.context.retry_count_name})")

            if self.context.retry_count_name >= Config.MAX_RETRY_NAME:
                self.state = ValidationState.ESCALATE_TO_HUMAN
                return "Xin lỗi quý khách. Tôi sẽ chuyển máy cho nhân viên."

            return await self.llm_functions.generate_response(
                self.state,
                self.context,
                {"retry": True, "retry_count": self.context.retry_count_name}
            )

    async def _validate_user(self) -> str:
        """Validate user against database"""
        logger.info(f"Validating: {self.context.phone_number} | {self.context.customer_name}")

        db_customer = self.customer_db.lookup(self.context.phone_number)

        if db_customer:
            self.context.database_name = db_customer['name']
            self.context.account_id = db_customer['account_id']

            similarity = self.customer_db.fuzzy_match_name(
                self.context.customer_name,
                db_customer['name']
            )

            logger.info(f"Name similarity: {similarity:.2f}")

            if similarity > 0.8:
                self.context.customer_status = "existing_verified"
                logger.info(f"Customer verified: {db_customer['account_id']}")
            else:
                self.context.customer_status = "existing_mismatch"
                logger.warning(f"Name mismatch")
        else:
            self.context.customer_status = "new_customer"
            logger.info(f"New customer")

        self.state = ValidationState.VALIDATION_COMPLETE

        response = await self.llm_functions.generate_response(self.state, self.context, None)
        return response


# ============================================================================
# INTENT DETECTION STATE MACHINE (Phase 2 - NEW)
# ============================================================================

class IntentDetectionStateMachine:
    """State machine for Phase 2: Intent Detection"""

    def __init__(self, llm_functions: LLMFunctions, validation_context: ValidationContext):
        self.state = IntentDetectionState.AWAITING_INTENT
        self.context = IntentContext()
        self.llm_functions = llm_functions
        self.validation_context = validation_context

        logger.info("Intent detection state machine initialized")

    def is_complete(self) -> bool:
        """Check if intent is confirmed"""
        return self.state == IntentDetectionState.INTENT_CONFIRMED

    def get_context(self) -> IntentContext:
        """Get current intent context"""
        return self.context

    async def process(self, user_speech: str) -> str:
        """Process user input for intent detection"""
        logger.info(f"Intent State: {self.state.value} | User: {user_speech}")

        if self.state == IntentDetectionState.AWAITING_INTENT:
            return await self._handle_awaiting_intent(user_speech)

        elif self.state == IntentDetectionState.INTENT_CONFIRMED:
            # Intent confirmed, ready for Phase 3
            return ""

        else:
            logger.error(f"Unknown intent state: {self.state}")
            return "Xin lỗi, có lỗi xảy ra."

    async def _handle_awaiting_intent(self, user_speech: str) -> str:
        """Handle AWAITING_INTENT state"""
        # Detect intent using LLM
        result = await self.llm_functions.detect_intent(user_speech, self.validation_context)

        if result.intent != IntentType.UNCLEAR and result.confidence >= Config.INTENT_CONFIDENCE_THRESHOLD:
            # Intent detected with high confidence
            self.context.detected_intent = result.intent
            self.context.intent_confidence = result.confidence
            self.context.user_request_text = result.user_description

            logger.info(f"Intent detected: {result.intent.value} (conf: {result.confidence:.2f})")

            # Move to confirmation
            self.state = IntentDetectionState.INTENT_CONFIRMED
            return await self.llm_functions.generate_response(self.state, self.context, None)

        else:
            # Intent unclear, ask for clarification
            self.context.retry_count_intent += 1
            logger.warning(f"Intent unclear (attempt {self.context.retry_count_intent})")

            if self.context.retry_count_intent >= Config.MAX_RETRY_INTENT:
                # Give up, escalate
                return "Xin lỗi, tôi chưa hiểu rõ yêu cầu của quý khách. Để được hỗ trợ tốt hơn, tôi sẽ chuyển máy cho nhân viên."

            # Ask for clarification
            return await self.llm_functions.generate_response(
                self.state,
                self.context,
                {"retry": True, "retry_count": self.context.retry_count_intent}
            )

# ============================================================================
# VOICE AGENT (INTEGRATED PHASE 1 + PHASE 2) - WITH SILERO VAD
# ============================================================================

class VoiceAgentWithValidationAndIntent:
    """Voice agent with Phase 1 (Validation) + Phase 2 (Intent Detection)

    Uses Silero VAD for speech endpoint detection instead of fixed chunk duration.
    """

    def __init__(self):
        # Initialize ASR
        logger.info(f"Loading ASR model...")
        self.asr_model = MyAsrModel(
            "model_ct2_fp16",
            device="cuda",
            compute_type="float16"
        )
        logger.info("ASR loaded")

        # Initialize LLM client
        logger.info(f"Connecting to vLLM: {Config.VLLM_BASE_URL}")
        self.llm_client = AsyncOpenAI(
            base_url=Config.VLLM_BASE_URL,
            api_key="EMPTY"
        )
        logger.info("LLM client connected")

        # Initialize TTS
        logger.info(f"Loading VieNeu TTS from: {Config.VIENEU_MODEL_DIR}")
        self.tts = Vieneu(Config.VIENEU_MODEL_DIR)
        self.voice_data = None
        if Config.VIENEU_VOICE_ID:
            self.voice_data = self.tts.get_preset_voice(Config.VIENEU_VOICE_ID)
            logger.info(f"VieNeu loaded with voice: {Config.VIENEU_VOICE_ID}")
        else:
            logger.info("VieNeu loaded with default voice")

        # Initialize databases
        self.customer_db = CustomerDatabase(Config.CUSTOMER_DB_FILE)
        self.outage_db = OutageDatabase(Config.OUTAGE_DB_FILE)
        self.incident_db = IncidentDatabase(Config.INCIDENT_DB_FILE)

        # Initialize validation state machine (Phase 1)
        self.validation_sm = ValidationStateMachine(self.llm_client, self.customer_db)

        # Intent detection state machine (Phase 2) - initialized after validation
        self.intent_sm = None

        # Track which phase we're in
        # VALIDATION -> INTENT_DETECTION -> INTENT_HANDLING/INCIDENT_COLLECTING
        #   -> CONVERSATION_ENDING -> ENDED (or back to INTENT_DETECTION)
        self.phase = "VALIDATION"

        # Audio buffer
        self.audio_buffer = bytearray()

        # Echo suppression: after sending TTS, ignore audio for the playback duration
        self._suppress_until = 0.0  # timestamp until which VAD triggers are ignored

        # VAD options (Silero VAD)
        self.vad_options = VadOptions(
            threshold=Config.VAD_THRESHOLD,
            min_speech_duration_ms=Config.VAD_MIN_SPEECH_MS,
            min_silence_duration_ms=Config.VAD_MIN_SILENCE_MS,
            speech_pad_ms=Config.VAD_SPEECH_PAD_MS,
            max_speech_duration_s=Config.VAD_MAX_BUFFER_SEC,
        )

        logger.info("Voice agent initialized (Phase 1 + Phase 2, Silero VAD)")

    def __del__(self):
        """Cleanup"""
        try:
            if hasattr(self, 'tts'):
                self.tts.close()
                logger.info("TTS closed")
        except Exception as e:
            logger.error(f"Error closing TTS: {e}")

    async def transcribe(self, audio_bytes: bytes) -> str:
        """Transcribe audio to text"""
        try:
            audio_array = np.frombuffer(audio_bytes, dtype=np.int16)
            audio_float = audio_array.astype(np.float32) / 32768.0

            segments, info = self.asr_model.transcribe(
                audio_float,
                language="vi",
                vad_filter=False
            )

            text = " ".join([seg.text for seg in segments]).strip()

            if len(text) < 1:
                logger.debug(f"Too short: '{text}'")
                return ""

            if len(text.split()) < 1:
                logger.debug(f"Too few words: '{text}'")
                return ""

            logger.info(f"Transcribed: {text}")
            return text

        except Exception as e:
            logger.error(f"Transcription error: {e}")
            return ""

    async def synthesize(self, text: str) -> bytes:
        """Synthesize speech with VieNeu TTS"""
        try:
            if not text or len(text.strip()) == 0:
                logger.warning("Empty text for TTS")
                return b''

            text = text.strip()
            if len(text) > 500:
                logger.warning(f"Text too long ({len(text)} chars), truncating")
                text = text[:500]

            text = ' '.join(text.split())

            logger.debug(f"Synthesizing: {text}")

            audio_spec = self.tts.infer(text=text, voice=self.voice_data)

            audio_file = tempfile.mktemp(suffix='.wav')
            self.tts.save(audio_spec, audio_file)

            if not os.path.exists(audio_file):
                logger.error("VieNeu didn't create audio file")
                return b''

            if os.path.getsize(audio_file) < 100:
                logger.error(f"Audio file too small")
                os.unlink(audio_file)
                return b''

            # Read audio
            try:
                with wave.open(audio_file, 'rb') as wav:
                    vieneu_sample_rate = wav.getframerate()
                    audio_data = wav.readframes(wav.getnframes())
            except wave.Error as e:
                logger.error(f"WAV read error: {e}")
                os.unlink(audio_file)
                return b''

            # Resample if needed
            if vieneu_sample_rate != Config.SAMPLE_RATE:
                audio_array = np.frombuffer(audio_data, dtype=np.int16)
                num_samples = int(len(audio_array) * Config.SAMPLE_RATE / vieneu_sample_rate)
                resampled = signal.resample(audio_array, num_samples)
                audio_data = resampled.astype(np.int16).tobytes()

            os.unlink(audio_file)

            logger.debug(f"Synthesized {len(audio_data)} bytes")
            return audio_data

        except Exception as e:
            logger.error(f"TTS error: {e}")
            try:
                if 'audio_file' in locals() and os.path.exists(audio_file):
                    os.unlink(audio_file)
            except:
                pass
            return b''

    async def process_audio(self, audio_chunk: bytes) -> Optional[bytes]:
        """Process incoming audio chunk using Silero VAD for endpoint detection.

        Accumulates audio in a buffer. On each chunk, runs VAD on the full buffer.
        When VAD detects speech has ended (trailing silence >= min_silence_duration_ms),
        extracts speech segments and processes through ASR -> State Machine -> TTS.
        """
        # Echo suppression: discard audio arriving while TTS is playing
        if time.time() < self._suppress_until:
            # TTS is still playing on client side — drop this audio chunk
            return None

        # If suppression just ended, clear stale buffer (contains echo)
        if self.audio_buffer and self._suppress_until > 0:
            elapsed_since_suppress = time.time() - self._suppress_until
            if elapsed_since_suppress < 0.5:  # within 0.5s of suppression ending
                self.audio_buffer.clear()
                self._suppress_until = 0.0
                logger.debug("Echo suppression ended, cleared stale audio buffer")

        self.audio_buffer.extend(audio_chunk)

        buffer_samples = len(self.audio_buffer) // 2  # int16 = 2 bytes per sample

        # Skip if buffer too small for meaningful VAD
        if buffer_samples < int(Config.SAMPLE_RATE * Config.VAD_MIN_BUFFER_SEC):
            return None

        # Convert to float32 for VAD
        audio_float = (
            np.frombuffer(bytes(self.audio_buffer), dtype=np.int16)
            .astype(np.float32) / 32768.0
        )

        # Force process if buffer exceeds max duration
        if len(audio_float) >= Config.SAMPLE_RATE * Config.VAD_MAX_BUFFER_SEC:
            logger.warning(f"Buffer exceeded {Config.VAD_MAX_BUFFER_SEC}s, force-processing")
            audio_to_process = bytes(self.audio_buffer)
            self.audio_buffer.clear()
            return await self._process_speech(audio_to_process)

        # Run VAD on entire buffer
        speech_timestamps = get_speech_timestamps(
            audio_float, vad_options=self.vad_options
        )

        if not speech_timestamps:
            # No speech detected — trim buffer if too large
            if len(audio_float) > Config.SAMPLE_RATE * Config.VAD_NO_SPEECH_TRIM_SEC:
                keep_bytes = int(Config.SAMPLE_RATE * 2 * 2)  # keep last 2 seconds
                self.audio_buffer = bytearray(self.audio_buffer[-keep_bytes:])
                logger.debug(f"No speech, trimmed buffer to {keep_bytes} bytes (2s)")
            return None

        # Check if speech has ended (last segment doesn't reach buffer end)
        last_end = speech_timestamps[-1]["end"]
        if last_end < len(audio_float):
            # Speech ended! Send the original raw buffer to ASR
            # (not the extracted segments — Whisper needs acoustic context)
            audio_to_process = bytes(self.audio_buffer)
            self.audio_buffer.clear()

            # Log speech stats
            speech_duration = sum(
                (ts["end"] - ts["start"]) for ts in speech_timestamps
            ) / Config.SAMPLE_RATE
            logger.info(
                f"VAD: {len(speech_timestamps)} segment(s), "
                f"{speech_duration:.1f}s speech in {len(audio_float)/Config.SAMPLE_RATE:.1f}s buffer"
            )

            return await self._process_speech(audio_to_process)

        # Still speaking, keep accumulating
        return None

    async def _process_speech(self, audio_bytes: bytes) -> Optional[bytes]:
        """Process detected speech through ASR -> State Machine -> TTS pipeline."""
        # 1. Transcribe
        user_text = await self.transcribe(audio_bytes)

        if not user_text:
            return None

        # Filter garbage transcriptions (echo artifacts, noise)
        clean_text = user_text.strip().strip('.,!?;:')
        if len(clean_text) < 2 or (len(clean_text.split()) < 2 and len(clean_text) < 5):
            logger.debug(f"Discarding short/garbage transcription: '{user_text}'")
            return None

        # 2. Process based on current phase
        assistant_text = None

        if self.phase == "VALIDATION":
            # Phase 1: Validation
            assistant_text = await self.validation_sm.process(user_text)

            # Check if validation complete
            if self.validation_sm.is_complete():
                logger.info(f"Phase 1 complete, transitioning to Phase 2")
                logger.info(f"Validation context: {json.dumps(self.validation_sm.context.to_dict(), ensure_ascii=False)}")

                # Initialize intent detection (Phase 2)
                self.intent_sm = IntentDetectionStateMachine(
                    self.validation_sm.llm_functions,
                    self.validation_sm.get_context()
                )
                self.phase = "INTENT_DETECTION"

        elif self.phase == "INTENT_DETECTION":
            # Phase 2: Intent Detection
            assistant_text = await self.intent_sm.process(user_text)

            # Check if intent confirmed
            if self.intent_sm.is_complete():
                logger.info(f"Phase 2 complete, intent confirmed")
                logger.info(f"Intent context: {json.dumps(self.intent_sm.context.to_dict(), ensure_ascii=False)}")

                confirmed_intent = self.intent_sm.get_context().detected_intent

                # Handle intents immediately
                if confirmed_intent == IntentType.BILLING_INQUIRY:
                    assistant_text += " " + self._handle_billing_inquiry()
                    self.phase = "CONVERSATION_ENDING"
                elif confirmed_intent == IntentType.OUTAGE_INFO:
                    assistant_text += " " + self._handle_outage_info()
                    self.phase = "CONVERSATION_ENDING"
                elif confirmed_intent == IntentType.NEW_REGISTRATION:
                    assistant_text += " " + self._handle_new_registration()
                    self.phase = "CONVERSATION_ENDING"
                elif confirmed_intent == IntentType.INCIDENT_REPORT:
                    assistant_text += " " + self._handle_incident_report_start()
                    # phase set to INCIDENT_COLLECTING inside _handle_incident_report_start

        elif self.phase == "INCIDENT_COLLECTING":
            # Collecting incident details from user
            assistant_text = await self._process_incident_details(user_text)

        elif self.phase == "CONVERSATION_ENDING":
            # User was asked "cần hỗ trợ thêm gì không?"
            assistant_text = self._handle_conversation_ending(user_text)

            # If user had a new request, phase is now INTENT_DETECTION
            # Process the speech through intent detection immediately
            if self.phase == "INTENT_DETECTION" and not assistant_text:
                assistant_text = await self.intent_sm.process(user_text)
                if self.intent_sm.is_complete():
                    self.phase = "INTENT_HANDLING"

        elif self.phase == "ENDED":
            # Conversation ended, ignore further input
            return None

        else:
            assistant_text = "Có lỗi xảy ra trong hệ thống."

        # 3. Synthesize
        if assistant_text:
            audio_response = await self.synthesize(assistant_text)
            return audio_response

        return None

    def _handle_billing_inquiry(self) -> str:
        """Handle BILLING_INQUIRY intent — look up last month's charge."""
        ctx = self.validation_sm.get_context()

        if ctx.customer_status != "existing_verified":
            return ("Xin lỗi, hệ thống chưa có thông tin tài khoản của quý khách. "
                    "Quý khách vui lòng liên hệ nhân viên để được hỗ trợ. "
                    "Quý khách cần hỗ trợ thêm gì không ạ?")

        customer = self.customer_db.lookup(ctx.phone_number)
        charge = customer.get("electricity_charge_vnd") if customer else None

        if not charge:
            return ("Xin lỗi, hệ thống không tìm thấy thông tin hóa đơn của quý khách. "
                    "Quý khách cần hỗ trợ thêm gì không ạ?")

        # Format charge with thousands separator (e.g. 285000 -> 285.000)
        try:
            charge_formatted = f"{int(charge):,}".replace(",", ".")
        except ValueError:
            charge_formatted = charge

        logger.info(f"Billing inquiry: {ctx.phone_number} -> {charge_formatted} VND")

        return (f"Tiền điện tháng trước của quý khách {ctx.customer_name}, "
                f"mã khách hàng {ctx.account_id}, là {charge_formatted} đồng. "
                f"Quý khách cần hỗ trợ thêm gì không ạ?")

    def _handle_outage_info(self) -> str:
        """Handle OUTAGE_INFO intent — list current outage areas."""
        outages = self.outage_db.get_all_outages()

        if not outages:
            return ("Hiện tại không có thông tin mất điện nào trong hệ thống. "
                    "Quý khách cần hỗ trợ thêm gì không ạ?")

        lines = ["Thông tin mất điện hiện tại:"]
        for o in outages:
            locality = o.get("locality", "Không rõ")
            reason = o.get("reason", "Không rõ nguyên nhân")
            restore = o.get("expected_restore", "chưa xác định")
            lines.append(
                f"Khu vực {locality}, nguyên nhân {reason}, "
                f"dự kiến có điện lại lúc {restore}."
            )

        logger.info(f"Outage info: {len(outages)} records returned")
        return " ".join(lines) + " Quý khách cần hỗ trợ thêm gì không ạ?"

    def _handle_new_registration(self) -> str:
        """Handle NEW_REGISTRATION intent — register new customer."""
        ctx = self.validation_sm.get_context()

        if ctx.customer_status != "new_customer":
            return (f"Quý khách {ctx.customer_name} đã có tài khoản "
                    f"mã {ctx.account_id} trong hệ thống rồi ạ. "
                    f"Quý khách cần hỗ trợ thêm gì không ạ?")

        account_id = self.customer_db.register(ctx.phone_number, ctx.customer_name)

        return (f"Đã đăng ký thành công cho quý khách {ctx.customer_name}, "
                f"số điện thoại {ctx.phone_number}. "
                f"Mã khách hàng mới của quý khách là {account_id}. "
                f"Quý khách cần hỗ trợ thêm gì không ạ?")

    def _handle_conversation_ending(self, user_text: str) -> str:
        """Handle CONVERSATION_ENDING phase.

        If user says no / nothing else, end gracefully.
        If user has a new request, loop back to intent detection.
        """
        user_lower = user_text.lower()

        # Check if user wants to end (Vietnamese + English)
        end_phrases = [
            # Vietnamese
            'không', 'không cần', 'hết rồi', 'thôi', 'không có',
            'vậy thôi', 'được rồi', 'cảm ơn', 'tạm biệt', 'chào',
            # English (ASR sometimes outputs English)
            'no', 'bye', 'goodbye', 'thank you', 'thanks', 'that\'s all',
            'nothing', 'no thank', 'don\'t', 'close',
        ]
        wants_end = any(phrase in user_lower for phrase in end_phrases)

        if wants_end:
            self.phase = "ENDED"
            ctx = self.validation_sm.get_context()
            name = ctx.customer_name or "quý khách"
            return (f"Cảm ơn {name} đã liên hệ tổng đài điện lực. "
                    f"Chúc quý khách một ngày tốt lành. Tạm biệt ạ!")

        # User has another request — loop back to intent detection
        logger.info("User has another request, restarting intent detection")
        self.intent_sm = IntentDetectionStateMachine(
            self.validation_sm.llm_functions,
            self.validation_sm.get_context()
        )
        self.phase = "INTENT_DETECTION"
        # Process the new request through intent detection immediately
        # so user doesn't have to repeat themselves
        return None  # will be handled below

    def _handle_incident_report_start(self) -> str:
        """Handle INCIDENT_REPORT intent — ask user for incident details."""
        self.phase = "INCIDENT_COLLECTING"
        self.incident_info = {"locality": None, "description": None}
        self.incident_retry_count = 0
        return ("Quý khách vui lòng cho biết địa điểm xảy ra sự cố "
                "và mô tả tình trạng cụ thể ạ? "
                "Ví dụ: khu vực nào, mất điện hay điện yếu.")

    async def _process_incident_details(self, user_text: str) -> str:
        """Extract incident info from user speech and save to database.

        Retries up to 2 times if extraction fails. After max retries,
        gives up gracefully and moves to CONVERSATION_ENDING phase.
        """
        MAX_INCIDENT_RETRY = 2

        extracted = await self.validation_sm.llm_functions.extract_incident_info(user_text)

        locality = extracted.get("locality") or self.incident_info.get("locality")
        description = extracted.get("description") or self.incident_info.get("description")

        if not locality and not description:
            self.incident_retry_count += 1
            logger.warning(f"Incident extraction failed (attempt {self.incident_retry_count})")

            if self.incident_retry_count >= MAX_INCIDENT_RETRY:
                self.phase = "CONVERSATION_ENDING"
                return ("Xin lỗi, tôi chưa ghi nhận được thông tin sự cố. "
                        "Quý khách vui lòng gọi lại hoặc liên hệ nhân viên "
                        "để được hỗ trợ trực tiếp. "
                        "Quý khách cần hỗ trợ thêm gì không ạ?")

            return ("Xin lỗi, tôi chưa nghe rõ. Quý khách vui lòng cho biết "
                    "địa điểm và mô tả sự cố ạ?")

        if not locality:
            self.incident_retry_count += 1
            self.incident_info["description"] = description
            logger.warning(f"Missing locality (attempt {self.incident_retry_count})")

            if self.incident_retry_count >= MAX_INCIDENT_RETRY:
                self.phase = "CONVERSATION_ENDING"
                return ("Xin lỗi, tôi chưa xác định được khu vực sự cố. "
                        "Quý khách vui lòng gọi lại hoặc liên hệ nhân viên "
                        "để được hỗ trợ trực tiếp. "
                        "Quý khách cần hỗ trợ thêm gì không ạ?")

            return (f"Tôi ghi nhận sự cố: {description}. "
                    "Quý khách vui lòng cho biết khu vực xảy ra sự cố ạ?")

        if not description:
            self.incident_retry_count += 1
            self.incident_info["locality"] = locality
            logger.warning(f"Missing description (attempt {self.incident_retry_count})")

            if self.incident_retry_count >= MAX_INCIDENT_RETRY:
                self.phase = "CONVERSATION_ENDING"
                return ("Xin lỗi, tôi chưa ghi nhận được mô tả sự cố. "
                        "Quý khách vui lòng gọi lại hoặc liên hệ nhân viên "
                        "để được hỗ trợ trực tiếp. "
                        "Quý khách cần hỗ trợ thêm gì không ạ?")

            return (f"Tôi ghi nhận khu vực {locality}. "
                    "Quý khách vui lòng mô tả tình trạng sự cố ạ?")

        # Save to database
        ctx = self.validation_sm.get_context()
        phone = ctx.phone_number or "unknown"
        name = ctx.customer_name or "unknown"

        self.incident_db.save(phone, name, locality, description)

        self.phase = "CONVERSATION_ENDING"

        logger.info(f"Incident saved: {phone} / {name} / {locality} / {description}")

        return (f"Đã ghi nhận sự cố tại khu vực {locality}, "
                f"nội dung: {description}. "
                f"Chúng tôi sẽ xử lý sớm nhất. Cảm ơn quý khách đã báo cáo. "
                f"Quý khách cần hỗ trợ thêm gì không ạ?")


# ============================================================================
# WEBSOCKET SERVER
# ============================================================================

async def handle_client(websocket):
    """Handle WebSocket client connection"""
    client_id = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}"
    logger.info(f"Client connected: {client_id}")

    agent = VoiceAgentWithValidationAndIntent()

    try:
        # Send greeting
        greeting = await agent.validation_sm.start()
        greeting_audio = await agent.synthesize(greeting)
        if greeting_audio:
            await websocket.send(greeting_audio)
            # Suppress echo for the duration of TTS playback
            playback_secs = len(greeting_audio) / (Config.SAMPLE_RATE * 2)
            agent._suppress_until = time.time() + playback_secs
            agent.audio_buffer.clear()

        # Process incoming audio
        async for message in websocket:
            if isinstance(message, bytes):
                response_audio = await agent.process_audio(message)

                if response_audio:
                    await websocket.send(response_audio)
                    # Suppress echo for the duration of TTS playback
                    playback_secs = len(response_audio) / (Config.SAMPLE_RATE * 2)
                    agent._suppress_until = time.time() + playback_secs
                    agent.audio_buffer.clear()

    except websockets.exceptions.ConnectionClosed:
        logger.info(f"Client disconnected: {client_id}")
    except Exception as e:
        logger.error(f"Error handling client {client_id}: {e}")
        import traceback
        traceback.print_exc()
    finally:
        logger.info(f"Cleaned up connection: {client_id}")


async def start_server():
    """Start WebSocket server"""
    logger.info(f"Starting WebSocket server on {Config.WS_HOST}:{Config.WS_PORT}")
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ssl_context.load_cert_chain("server.crt", "server.key")

    async with websockets.serve(
        handle_client,
        Config.WS_HOST,
        Config.WS_PORT,
        ssl=ssl_context,
        max_size=10_000_000
    ):
        logger.info(f"Voice agent listening on ws://{Config.WS_HOST}:{Config.WS_PORT}")
        logger.info("Press Ctrl+C to stop")
        await asyncio.Future()


# ============================================================================
# MAIN
# ============================================================================

async def main():
    """Main entry point"""
    logger.info("="*70)
    logger.info("ELECTRICITY CALL CENTER - PHASE 1 + PHASE 2 (Silero VAD)")
    logger.info("="*70)
    logger.info(f"Phase 1: User Validation")
    logger.info(f"Phase 2: Intent Detection")
    logger.info(f"Phase 3: Intent Handling (TODO)")
    logger.info("")
    logger.info(f"Language:     Vietnamese")
    logger.info(f"LLM Server:   {Config.VLLM_BASE_URL}")
    logger.info(f"TTS Model:    VieNeu ({Config.VIENEU_MODEL_DIR})")
    logger.info(f"Customer DB:  {Config.CUSTOMER_DB_FILE}")
    logger.info(f"WebSocket:    ws://{Config.WS_HOST}:{Config.WS_PORT}")
    logger.info(f"VAD:          Silero VAD v6 (threshold={Config.VAD_THRESHOLD}, "
                f"silence={Config.VAD_MIN_SILENCE_MS}ms, "
                f"min_speech={Config.VAD_MIN_SPEECH_MS}ms)")
    logger.info("="*70)
    logger.info("")

    # Start server
    try:
        await start_server()
    except KeyboardInterrupt:
        logger.info("")
        logger.info("Shutting down...")


if __name__ == "__main__":
    asyncio.run(main())
