try:
    import torch
    import vllm
    from vllm.transformers_utils.tokenizer import MistralTokenizer
    from transformers import PreTrainedTokenizerBase
except ImportError:
    raise ImportError('vllm is not installed. Please install it with "pip install vllm"')
from lmformatenforcer import CharacterLevelParser, TokenEnforcer, FormatEnforcerAnalyzer, TokenEnforcerTokenizerData
from lmformatenforcer.integrations.transformers import build_token_enforcer_tokenizer_data
from typing import List, Optional, Union
import math


class VLLMLogitsProcessor:
    """Callable logits processor that masks out tokens the TokenEnforcer disallows.

    Calling ``processor(input_ids, scores)`` returns a new tensor equal to
    ``scores`` at allowed token positions and ``-inf`` everywhere else.
    """

    def __init__(self, token_enforcer: TokenEnforcer, analyze: bool):
        """
        :param token_enforcer: Decides which token ids may follow a given token sequence.
        :param analyze: When truthy, a FormatEnforcerAnalyzer wrapping ``token_enforcer``
            receives the raw logits of every call.
        """
        self.token_enforcer = token_enforcer
        self.analyzer = FormatEnforcerAnalyzer(token_enforcer) if analyze else None
        # Reusable additive mask, allocated lazily on the first call so that it
        # inherits shape, dtype and device from the scores tensor.
        self.mask: Optional[torch.Tensor] = None

    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
        """Return ``scores`` with every token not allowed after ``input_ids`` set to -inf.

        :param input_ids: Token ids generated so far; passed to the enforcer unchanged.
        :param scores: Logits for the next token. Not modified in place.
        :return: A new tensor, ``scores + mask``.
        """
        if self.analyzer:
            self.analyzer.report_raw_logits(input_ids, scores.tolist())
        allowed_tokens = self.token_enforcer.get_allowed_tokens(input_ids).allowed_tokens
        mask = self.mask
        if (mask is None
                or mask.shape != scores.shape
                or mask.dtype != scores.dtype
                or mask.device != scores.device):
            # full_like() copies shape, dtype and device from scores. Rebuild the
            # cached mask whenever any of them changed, so a stale mask is never
            # broadcast against a differently shaped tensor, used across devices,
            # or allowed to silently promote the dtype of the returned scores.
            mask = torch.full_like(scores, -math.inf)
            self.mask = mask
        else:
            # Reuse the existing allocation; reset it to "everything disallowed".
            mask.fill_(-math.inf)
        mask[allowed_tokens] = 0
        return scores + mask


def build_vllm_token_enforcer_tokenizer_data(tokenizer: Union[vllm.LLM, PreTrainedTokenizerBase], use_bitmask: bool = False, vocab_size: int | None = None) -> TokenEnforcerTokenizerData:
    """Create TokenEnforcerTokenizerData from a vLLM engine or a tokenizer-like object.

    Several kinds of objects may be passed, so unwrapping is duck-typed:

    * If ``vocab_size`` is not given and the object has ``llm_engine`` (as ``vllm.LLM``
      does), the vocab size is read from the engine's model config.
    * An object exposing ``get_tokenizer()`` is replaced by the tokenizer it returns.
    * A ``MistralTokenizer`` is used directly. Any other object with an inner
      ``.tokenizer`` attribute is unwrapped one level.

    :param tokenizer: A ``vllm.LLM``, a Hugging Face tokenizer, or a wrapper around one.
    :param use_bitmask: Forwarded unchanged to ``build_token_enforcer_tokenizer_data``.
    :param vocab_size: Explicit vocab size. If None, it is inferred from ``llm_engine`` when available.
    :return: The tokenizer data built by ``build_token_enforcer_tokenizer_data``.
    """
    if vocab_size is None and hasattr(tokenizer, 'llm_engine'):
        model_config = tokenizer.llm_engine.get_model_config()
        vocab_size = model_config.get_vocab_size()

    base_tokenizer = tokenizer.get_tokenizer() if hasattr(tokenizer, 'get_tokenizer') else tokenizer

    # Mistral tokenizers are handed over as-is; other wrappers expose the real
    # tokenizer through a `.tokenizer` attribute.
    if not isinstance(base_tokenizer, MistralTokenizer) and hasattr(base_tokenizer, 'tokenizer'):
        base_tokenizer = base_tokenizer.tokenizer

    return build_token_enforcer_tokenizer_data(base_tokenizer, use_bitmask, vocab_size)


def build_vllm_logits_processor(llm: Union[vllm.LLM, PreTrainedTokenizerBase, TokenEnforcerTokenizerData], 
                                character_level_parser: CharacterLevelParser, 
                                analyze: bool=False) -> VLLMLogitsProcessor:
    """Build a logits processor that restricts vLLM generation to output accepted by ``character_level_parser``.

    The returned VLLMLogitsProcessor can be passed to vLLM as a per-request logits
    processor, e.g. in the ``logits_processors`` list of ``vllm.SamplingParams``
    (for vLLM versions that support per-request logits processors).

    :param llm: A ``vllm.LLM``, a tokenizer, or precomputed TokenEnforcerTokenizerData.
        If precomputed tokenizer data is given, the
        ``build_vllm_token_enforcer_tokenizer_data()`` step is skipped, so one instance
        can be reused across many processors.
    :param character_level_parser: Parser that defines which output is allowed.
    :param analyze: If True, the processor reports raw logits to a FormatEnforcerAnalyzer.
    :return: A callable ``(input_ids, scores) -> scores`` logits processor.
    """
    if isinstance(llm, TokenEnforcerTokenizerData):
        tokenizer_data = llm
    else:
        tokenizer_data = build_vllm_token_enforcer_tokenizer_data(llm)
    token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
    return VLLMLogitsProcessor(token_enforcer, analyze)


# Names exported by `from ... import *`; VLLMLogitsProcessor is reachable via the builders.
__all__ = [
    'build_vllm_logits_processor',
    'build_vllm_token_enforcer_tokenizer_data',
]
