# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tokenizer for Grok-2 .tok.json format."""

import functools
import json
from collections.abc import Collection, Set
from pathlib import Path
from typing import Any, Literal, overload

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import (
    EntryNotFoundError,
    HfHubHTTPError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
)
from transformers import BatchEncoding
from transformers.utils import chat_template_utils as hf_chat_utils

from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.logger import init_logger

from .protocol import TokenizerLike

logger = init_logger(__name__)

PAD = "<|pad|>"
EOS = "<|eos|>"
SEP = "<|separator|>"
RESERVED_TOKEN_TEXTS = [f"<|reserved_{i}|>" for i in range(3, 128)]
CONTROL_TOKEN_TEXTS = [f"<|control{i}|>" for i in range(1, 705)]
DEFAULT_SPECIAL_TOKENS = [PAD, SEP, EOS]
DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": SEP, "eos": EOS}
DEFAULT_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ 'Human: ' + message['content'].strip() + '<|separator|>\\n\\n' }}"
    "{% elif message['role'] == 'system' %}"
    "{{ 'System: ' + message['content'].strip() + '<|separator|>\\n\\n' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ 'Assistant: ' + message['content'] + '<|separator|>\\n\\n' }}"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ 'Assistant:' }}"
    "{% endif %}"
)

# Default + separate each single digit.
PAT_STR_B = (
    r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}|"""
    r""" ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
)


def _maybe_load_tokenizer_config(
    model_path: Path,
    *,
    repo_id: str | None,
    revision: str | None,
    download_dir: str | None,
) -> dict[str, Any]:
    config_path = model_path / "tokenizer_config.json"
    if config_path.is_file():
        with config_path.open("r", encoding="utf-8") as f:
            return json.load(f)

    if repo_id is None:
        return {}

    try:
        config_file = hf_hub_download(
            repo_id=repo_id,
            filename="tokenizer_config.json",
            revision=revision,
            cache_dir=download_dir,
        )
    except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError):
        # If the repo, revision, or file does not exist, fall back silently.
        return {}
    except HfHubHTTPError as exc:
        logger.warning(
            "Failed to download tokenizer_config.json from %s. "
            "This may be due to a network or authentication issue. "
            "The default chat template will be used. Error: %s",
            repo_id,
            exc,
        )
        return {}

    try:
        with Path(config_file).open("r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as exc:
        logger.warning(
            "Failed to parse tokenizer_config.json. "
            "The default chat template will be used. Error: %s",
            exc,
        )
        return {}
    except OSError as exc:
        logger.warning(
            "Failed to open tokenizer_config.json. "
            "The default chat template will be used. Error: %s",
            exc,
        )
        return {}


def _load_tiktoken_encoding(
    vocab_file: Path,
) -> tuple[Any, dict[str, int]]:
    try:
        import tiktoken
    except ImportError as exc:
        raise ImportError("Grok-2 tokenizer requires the `tiktoken` package.") from exc

    with vocab_file.open("rb") as f:
        xtok_dict = json.load(f)

    mergeable_ranks = {
        bytes(item["bytes"]): item["token"]
        for item in xtok_dict.get("regular_tokens", [])
    }
    special_tokens = {
        bytes(item["bytes"]).decode("utf-8", errors="replace"): item["token"]
        for item in xtok_dict.get("special_tokens", [])
    }

    if xtok_dict.get("word_split") == "V1":
        pat_str = PAT_STR_B
    else:
        raise ValueError(f"Unknown word_split: {xtok_dict.get('word_split')!r}")

    pat_str = xtok_dict.get("pat_str", pat_str)

    kwargs = {
        "name": str(vocab_file),
        "pat_str": pat_str,
        "mergeable_ranks": mergeable_ranks,
        "special_tokens": special_tokens,
    }

    if "vocab_size" in xtok_dict:
        kwargs["explicit_n_vocab"] = xtok_dict["vocab_size"]

    tokenizer = tiktoken.Encoding(**kwargs)

    default_allowed_special: set[str] | None = None
    if "default_allowed_special" in xtok_dict:
        default_allowed_special = {
            bytes(bytes_list).decode("utf-8", errors="replace")
            for bytes_list in xtok_dict["default_allowed_special"]
        }

    tokenizer._default_allowed_special = default_allowed_special or set()
    tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS

    def encode_patched(
        self,
        text: str,
        *,
        allowed_special: Literal["all"] | Set[str] = set(),
        disallowed_special: Literal["all"] | Collection[str] = "all",
    ) -> list[int]:
        del disallowed_special
        if isinstance(allowed_special, set):
            allowed_special |= self._default_allowed_special
        return tiktoken.Encoding.encode(
            self,
            text,
            allowed_special=allowed_special,
            disallowed_special=(),
        )

    tokenizer.encode = functools.partial(encode_patched, tokenizer)
    tokenizer._default_allowed_special |= set(DEFAULT_CONTROL_TOKENS.values())
    tokenizer._default_allowed_special |= set(
        CONTROL_TOKEN_TEXTS + RESERVED_TOKEN_TEXTS
    )

    return tokenizer, special_tokens


class Grok2Tokenizer(TokenizerLike):
    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *args,
        trust_remote_code: bool = False,
        revision: str | None = None,
        download_dir: str | None = None,
        **kwargs,
    ) -> "Grok2Tokenizer":
        if args:
            logger.debug_once("Ignoring extra positional args for Grok2Tokenizer.")

        path = Path(path_or_repo_id)
        if path.is_file():
            vocab_file = path
            model_path = path.parent
            repo_id = None
        elif path.is_dir():
            vocab_file = path / "tokenizer.tok.json"
            model_path = path
            repo_id = None
        else:
            vocab_file = Path(
                hf_hub_download(
                    repo_id=str(path_or_repo_id),
                    filename="tokenizer.tok.json",
                    revision=revision,
                    cache_dir=download_dir,
                )
            )
            model_path = vocab_file.parent
            repo_id = str(path_or_repo_id)

        if not vocab_file.is_file():
            raise FileNotFoundError(f"tokenizer.tok.json not found at {vocab_file}.")

        config = _maybe_load_tokenizer_config(
            model_path,
            repo_id=repo_id,
            revision=revision,
            download_dir=download_dir,
        )

        return cls(
            vocab_file=vocab_file,
            name_or_path=str(path_or_repo_id),
            truncation_side=kwargs.get("truncation_side", "left"),
            chat_template=config.get("chat_template"),
            init_kwargs=config,
        )

    def __init__(
        self,
        *,
        vocab_file: Path,
        name_or_path: str,
        truncation_side: str,
        chat_template: str | None,
        init_kwargs: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        self.name_or_path = name_or_path
        self._truncation_side = truncation_side
        self.init_kwargs = init_kwargs or {}
        self._chat_template = chat_template or DEFAULT_CHAT_TEMPLATE

        self._tokenizer, self._special_tokens = _load_tiktoken_encoding(vocab_file)

        self._token_to_id: dict[str, int] = {}
        self._id_to_token: dict[int, str] = {}
        for token, token_id in self._tokenizer._mergeable_ranks.items():
            token_str = token.decode("utf-8", errors="replace")
            self._token_to_id[token_str] = token_id
            self._id_to_token[token_id] = token_str

        for token, token_id in self._special_tokens.items():
            self._token_to_id[token] = token_id
            self._id_to_token[token_id] = token

        bos_token_id = self._special_tokens.get(SEP)
        if bos_token_id is None:
            bos_token_id = self._special_tokens.get(PAD)
        if bos_token_id is None:
            bos_token_id = self._special_tokens.get(EOS)
        if bos_token_id is None:
            bos_token_id = 0
        self._bos_token_id = bos_token_id

        self._eos_token_id = self._special_tokens.get(EOS, self._bos_token_id)
        self._pad_token_id = self._special_tokens.get(PAD, self._eos_token_id)
        self._unk_token_id = self._pad_token_id

    def num_special_tokens_to_add(self) -> int:
        return 0

    @property
    def all_special_tokens(self) -> list[str]:
        return list(self._special_tokens.keys())

    @property
    def all_special_ids(self) -> list[int]:
        return list(self._special_tokens.values())

    @property
    def bos_token_id(self) -> int:
        return self._bos_token_id

    @property
    def eos_token_id(self) -> int:
        return self._eos_token_id

    @property
    def pad_token_id(self) -> int:
        return self._pad_token_id

    @property
    def is_fast(self) -> bool:
        return False

    @property
    def vocab_size(self) -> int:
        return self._tokenizer.n_vocab

    @property
    def max_token_id(self) -> int:
        return self._tokenizer.n_vocab - 1

    @property
    def truncation_side(self) -> str:
        return self._truncation_side

    def get_vocab(self) -> dict[str, int]:
        return dict(self._token_to_id)

    def get_added_vocab(self) -> dict[str, int]:
        return dict(self._special_tokens)

    def _maybe_truncate(self, tokens: list[int], max_length: int | None) -> list[int]:
        if max_length is None or len(tokens) <= max_length:
            return tokens
        if self.truncation_side == "left":
            return tokens[-max_length:]
        return tokens[:max_length]

    def encode(
        self,
        text: str,
        truncation: bool | None = None,
        max_length: int | None = None,
        add_special_tokens: bool = True,
    ) -> list[int]:
        del add_special_tokens
        tokens = self._tokenizer.encode(text)
        if truncation:
            tokens = self._maybe_truncate(tokens, max_length)
        return tokens

    def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
        if isinstance(ids, int):
            ids = [ids]
        if skip_special_tokens:
            ids = [
                token_id
                for token_id in ids
                if token_id not in self._special_tokens.values()
            ]
        return self._tokenizer.decode(ids)

    @overload
    def convert_tokens_to_ids(self, tokens: str) -> int: ...

    @overload
    def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...

    def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
        if isinstance(tokens, str):
            return self._token_to_id.get(tokens, self._unk_token_id)
        return [self._token_to_id.get(token, self._unk_token_id) for token in tokens]

    def convert_ids_to_tokens(
        self, ids: list[int], skip_special_tokens: bool = False
    ) -> list[str]:
        tokens = []
        for token_id in ids:
            if skip_special_tokens and token_id in self._special_tokens.values():
                continue
            tokens.append(self._id_to_token.get(token_id, "<|unk|>"))
        return tokens

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        token_ids = self.convert_tokens_to_ids(tokens)
        return self.decode(token_ids, skip_special_tokens=False)

    def __call__(
        self,
        text: str | list[str],
        text_pair: str | None = None,
        add_special_tokens: bool = True,
        truncation: bool = False,
        max_length: int | None = None,
    ) -> BatchEncoding:
        if text_pair is not None:
            raise NotImplementedError("text_pair is not supported for Grok2Tokenizer.")

        if isinstance(text, list):
            input_ids_batch: list[list[int]] = [
                self.encode(
                    item,
                    truncation=truncation,
                    max_length=max_length,
                    add_special_tokens=add_special_tokens,
                )
                for item in text
            ]
            attention_mask_batch = [[1] * len(ids) for ids in input_ids_batch]
            return BatchEncoding(
                {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
            )

        input_ids = self.encode(
            text,
            truncation=truncation,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
        )
        attention_mask = [1] * len(input_ids)
        return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})

    def get_chat_template(
        self, chat_template: str | None, tools: list[dict[str, Any]] | None = None
    ) -> str | None:
        del tools
        return chat_template or self._chat_template

    def apply_chat_template(
        self,
        messages: list[ChatCompletionMessageParam],
        tools: list[dict[str, Any]] | None = None,
        chat_template: str | None = None,
        tokenize: bool = False,
        **kwargs,
    ) -> str | list[int]:
        template = self.get_chat_template(chat_template, tools=tools)
        if template is None:
            raise ValueError(
                "No chat template available. Provide `chat_template` explicitly."
            )
        prompt = hf_chat_utils.apply_chat_template(
            conversation=messages,
            chat_template=template,
            tools=tools,
            **kwargs,
        )
        if tokenize:
            return self.encode(prompt, add_special_tokens=False)
        return prompt
