"""Utility functions for batch processing."""
import logging
from typing import TYPE_CHECKING, Any, Union

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

# Type alias for the tokenizer objects handled by this module. The transformers
# classes are written as string forward references because they are only
# imported under TYPE_CHECKING (no runtime dependency on transformers here).
# ``Any`` is part of the union, so static checkers effectively accept any
# object for this alias (e.g. tokenizer-like wrappers that are not transformers
# tokenizers).
AnyTokenizer = Union["PreTrainedTokenizer", "PreTrainedTokenizerFast", Any]

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
    """Get tokenizer with cached properties.

    This patches the tokenizer object in place by swapping its class for a
    dynamically created subclass. By default, transformers recomputes several
    tokenizer properties each time they are accessed, leading to a
    significant slowdown. This function computes those properties once and
    serves the stored values afterwards.

    The cached values are snapshots taken at call time: changes made to the
    tokenizer afterwards are not reflected in them. They are exposed as
    read-only properties, so assigning to one of them (for example
    ``tokenizer.chat_template = ...``) raises ``AttributeError``.

    The patched tokenizer remains picklable. Pickling stores a plain instance
    of the original class, and unpickling re-applies this function to it.

    Args:
        tokenizer: The tokenizer object, or a processor (e.g. for a VLM) that
            wraps its text tokenizer in a ``tokenizer`` attribute.

    Returns:
        The patched tokenizer object. When a processor is passed, the wrapped
        text tokenizer is patched and returned, not the processor itself.
    """
    chat_template = getattr(tokenizer, "chat_template", None)
    # For VLMs, the text tokenizer is wrapped by a processor.
    if hasattr(tokenizer, "tokenizer"):
        tokenizer = tokenizer.tokenizer
        # Some VLM tokenizers carry a chat_template (e.g. Qwen/Qwen2-VL-7B-Instruct),
        # while others do not (e.g. mistral-community/pixtral-12b). Therefore we
        # prefer the processor's chat_template and only fall back to the wrapped
        # tokenizer's own template when the processor has none.
        if chat_template is None:
            chat_template = getattr(tokenizer, "chat_template", None)

    # The class being replaced. __reduce__ needs it to rebuild an unpatched
    # instance, because the patch below happens in place.
    base_class = tokenizer.__class__

    tokenizer_all_special_ids = set(tokenizer.all_special_ids)
    tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
    # all_special_tokens_extended is removed in transformers v5, but it is still
    # used by the latest SGLang version. We require that SGLang version because
    # it is ABI compatible with PyTorch 2.9, which is installed by vLLM.
    # TODO(seiji) remove the attribute completely once vLLM moves to transformers v5.
    tokenizer_all_special_tokens_extended = getattr(
        tokenizer, "all_special_tokens_extended", None
    )
    tokenizer_len = len(tokenizer)

    class CachedTokenizer(base_class):  # type: ignore
        @property
        def all_special_ids(self):
            return tokenizer_all_special_ids

        @property
        def all_special_tokens(self):
            return tokenizer_all_special_tokens

        @property
        def all_special_tokens_extended(self):
            return tokenizer_all_special_tokens_extended

        @property
        def chat_template(self):
            return chat_template

        def __len__(self):
            return tokenizer_len

        def __reduce__(self):
            # This class is local to get_cached_tokenizer, so pickle cannot
            # find it by name ("Can't pickle local object"). Instead, pickle a
            # plain instance of the original class carrying the same state, and
            # re-apply the caching when it is unpickled. The resolved
            # chat_template is written into that state so that a template
            # inherited from a VLM processor survives the round trip.
            plain = base_class.__new__(base_class)
            plain.__dict__.update(self.__dict__)
            plain.__dict__["chat_template"] = chat_template
            return get_cached_tokenizer, (plain,)

    CachedTokenizer.__name__ = f"Cached{base_class.__name__}"

    tokenizer.__class__ = CachedTokenizer
    return tokenizer
