"""Python wrapper around the Rust implementation of *harmony*.

The heavy lifting (tokenisation, rendering, parsing, …) is implemented in
Rust.  The thin bindings are available through the private ``openai_harmony``
extension module which is compiled via *maturin* / *PyO3*.

This package provides a small, typed convenience layer that mirrors the public
API of the Rust crate so that it can be used from Python code in an
idiomatic way (``pydantic`` models, ``Enum``s, …).
"""

from __future__ import annotations

import functools
import json
from enum import Enum
from typing import (
    AbstractSet,
    Any,
    Collection,
    Dict,
    List,
    Literal,
    Optional,
    Pattern,
    Sequence,
    TypeVar,
    Union,
)

import re
from pydantic import BaseModel, Field

# Re-export the low-level Rust bindings under a private name so that we can
# keep the *public* namespace clean and purely Pythonic.
try:
    from .openai_harmony import (
        HarmonyError as HarmonyError,  # expose the actual Rust error directly
    )
    from .openai_harmony import PyHarmonyEncoding as _PyHarmonyEncoding  # type: ignore
    from .openai_harmony import (
        PyStreamableParser as _PyStreamableParser,  # type: ignore
    )
    from .openai_harmony import (
        load_harmony_encoding as _load_harmony_encoding,  # type: ignore
    )

except ModuleNotFoundError:  # pragma: no cover – raised during type-checking
    # When running *mypy* without the compiled extension in place we still want
    # to succeed.  Therefore we create dummy stubs that satisfy the type
    # checker.  They will, however, raise at **runtime** if accessed.

    class _Stub:  # pylint: disable=too-few-public-methods
        """Placeholder for a missing binding.

        Any attribute access *or call* on an instance raises ``RuntimeError``
        with a hint on how to build the extension module.
        """

        _MESSAGE = (
            "The compiled harmony bindings are not available. Make sure to "
            "build the project with `maturin develop` before running this "
            "code."
        )

        def __getattr__(self, name: str) -> None:  # noqa: D401
            raise RuntimeError(self._MESSAGE)

        def __call__(self, *args: Any, **kwargs: Any) -> None:
            # Some stubs stand in for callables (a loader function and a
            # parser class).  ``__getattr__`` is not consulted for implicit
            # dunder lookups, so without this method calling a stub would
            # fail with an unhelpful ``TypeError`` instead of the build hint.
            raise RuntimeError(self._MESSAGE)

    _load_harmony_encoding = _Stub()  # type: ignore
    _PyHarmonyEncoding = _Stub()  # type: ignore
    _PyStreamableParser = _Stub()  # type: ignore
    # The rest of this module (and ``__all__``) refers to the public name
    # ``HarmonyError``, so the fallback has to bind that name as well.
    HarmonyError = RuntimeError  # type: ignore
    # Kept so that anything relying on the previous private alias still works.
    _HarmonyError = RuntimeError


@functools.lru_cache(maxsize=128)
def _special_token_regex(tokens: frozenset[str]) -> Pattern[str]:
    """Return a compiled pattern matching any of the literal *tokens*.

    Each token is escaped, so it is matched verbatim rather than interpreted
    as a regular expression.  The alternatives are wrapped in one capturing
    group.

    The result is memoised per distinct token set (bounded to 128 entries),
    which is why *tokens* must be hashable, i.e. a ``frozenset``.

    Args:
        tokens: The special-token strings to search for.

    Returns:
        The compiled regular expression.
    """
    inner = "|".join(re.escape(token) for token in tokens)
    return re.compile(f"({inner})")


def raise_disallowed_special_token(token: str) -> None:
    """Raise ``HarmonyError`` reporting that *token* is a disallowed special token.

    The message names the offending token and lists the ``allowed_special`` /
    ``disallowed_special`` options that can be used to change the check.

    Args:
        token: The special-token text that was found in the input.

    Raises:
        HarmonyError: Always.
    """
    message_parts = (
        "Encountered text corresponding to disallowed special token ",
        f"{token!r}.\n",
        "If you want this text to be encoded as a special token, ",
        f"pass it to `allowed_special`, e.g. `allowed_special={{'{token}', ...}}`.\n",
        "If you want this text to be encoded as normal text, disable the check for this token ",
        f"by passing `disallowed_special=(enc.special_tokens_set - {{'{token}'}})`.\n",
        "To disable this check for all special tokens, pass `disallowed_special=()`.\n",
    )
    raise HarmonyError("".join(message_parts))


# ---------------------------------------------------------------------------
# Chat-related data-structures (mirroring ``src/chat.rs``)
# ---------------------------------------------------------------------------


class Role(str, Enum):
    """Author role attached to a message (counterpart of ``chat::Role``).

    Members are also plain strings, so ``Role.USER == "user"`` holds.
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    DEVELOPER = "developer"
    TOOL = "tool"

    @classmethod
    def _missing_(cls, value: object) -> "Role":  # type: ignore[override]
        # Called by ``Enum`` when *value* matches no member; fail with a
        # message that names the offending value.
        complaint = f"Unknown role: {value!r}"
        raise ValueError(complaint)


class Author(BaseModel):
    # Who produced a message: a mandatory role plus an optional name.
    role: Role
    name: Optional[str] = None

    @classmethod
    def new(cls, role: Role, name: str) -> "Author":  # noqa: D401 – keep parity with Rust API
        """Build an author from a role and a name."""
        attributes = {"role": role, "name": name}
        return cls(**attributes)


# Content hierarchy ---------------------------------------------------------


# Generic type variable.  NOTE(review): no use of ``T`` is visible in this
# module — confirm whether it is still needed before removing it.
T = TypeVar("T")


class Content(BaseModel):  # noqa: D101 – simple wrapper
    # Common base of the content variants defined in this module.  Subclasses
    # are expected to override ``to_dict``; the base implementation only
    # raises.
    def to_dict(self) -> Dict[str, Any]:
        """Return the dict form of this content (not implemented on the base class).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError


class TextContent(Content):
    # Content variant that carries a single piece of text.
    text: str

    def to_dict(self) -> Dict[str, Any]:
        """Return this content as a mapping tagged with ``"type": "text"``."""
        return dict(type="text", text=self.text)


class ToolDescription(BaseModel):
    # Describes one tool: its name, a description and an optional
    # ``parameters`` dict.
    name: str
    description: str
    parameters: Optional[dict] = None

    @classmethod
    def new(
        cls, name: str, description: str, parameters: Optional[dict] = None
    ) -> "ToolDescription":  # noqa: D401
        """Build a tool description from its three fields."""
        attributes = {
            "name": name,
            "description": description,
            "parameters": parameters,
        }
        return cls(**attributes)


class ReasoningEffort(str, Enum):
    """Reasoning-effort levels; used as the type of ``SystemContent.reasoning_effort``.

    Members are also plain strings whose values are capitalised
    (``"Low"``, ``"Medium"``, ``"High"``).
    """

    LOW = "Low"
    MEDIUM = "Medium"
    HIGH = "High"


class ChannelConfig(BaseModel):
    # A list of channel names plus a flag saying whether a channel is
    # required.
    valid_channels: List[str]
    channel_required: bool

    @classmethod
    def require_channels(cls, channels: List[str]) -> "ChannelConfig":  # noqa: D401
        """Build a config listing *channels* with ``channel_required`` set to ``True``."""
        return cls(channel_required=True, valid_channels=channels)


class ToolNamespaceConfig(BaseModel):
    # A named group of tool descriptions with an optional description.
    name: str
    description: Optional[str] = None
    tools: List[ToolDescription]

    @staticmethod
    def browser() -> "ToolNamespaceConfig":
        """Return the ``"browser"`` namespace config provided by the Rust bindings."""
        # Imported lazily, so the compiled extension is only required when
        # this method is actually called.
        from .openai_harmony import get_tool_namespace_config

        return ToolNamespaceConfig(**get_tool_namespace_config("browser"))

    @staticmethod
    def python() -> "ToolNamespaceConfig":
        """Return the ``"python"`` namespace config provided by the Rust bindings."""
        # Imported lazily, so the compiled extension is only required when
        # this method is actually called.
        from .openai_harmony import get_tool_namespace_config

        return ToolNamespaceConfig(**get_tool_namespace_config("python"))


class SystemContent(Content):
    # Content of a system message.  Every field has a default, so
    # ``SystemContent()`` is valid; the ``with_*`` methods mutate the instance
    # in place and return it so that calls can be chained.
    model_identity: Optional[str] = (
        "You are ChatGPT, a large language model trained by OpenAI."
    )
    reasoning_effort: Optional[ReasoningEffort] = ReasoningEffort.MEDIUM
    conversation_start_date: Optional[str] = None
    knowledge_cutoff: Optional[str] = "2024-06"
    channel_config: Optional[ChannelConfig] = Field(
        default_factory=lambda: ChannelConfig.require_channels(
            ["analysis", "commentary", "final"]
        )
    )
    tools: Optional[dict[str, ToolNamespaceConfig]] = None

    @classmethod
    def new(cls) -> "SystemContent":
        """Return an instance populated with the default field values."""
        return cls()

    # Fluent interface ------------------------------------------------------

    def with_model_identity(self, model_identity: str) -> "SystemContent":
        """Set ``model_identity`` and return ``self``."""
        self.model_identity = model_identity
        return self

    def with_reasoning_effort(
        self, reasoning_effort: ReasoningEffort
    ) -> "SystemContent":
        """Set ``reasoning_effort`` and return ``self``."""
        self.reasoning_effort = reasoning_effort
        return self

    def with_conversation_start_date(
        self, conversation_start_date: str
    ) -> "SystemContent":
        """Set ``conversation_start_date`` and return ``self``."""
        self.conversation_start_date = conversation_start_date
        return self

    def with_knowledge_cutoff(self, knowledge_cutoff: str) -> "SystemContent":
        """Set ``knowledge_cutoff`` and return ``self``."""
        self.knowledge_cutoff = knowledge_cutoff
        return self

    def with_channel_config(self, channel_config: ChannelConfig) -> "SystemContent":
        """Set ``channel_config`` and return ``self``."""
        self.channel_config = channel_config
        return self

    def with_required_channels(self, channels: list[str]) -> "SystemContent":
        """Replace ``channel_config`` with one that requires *channels*; return ``self``."""
        required = ChannelConfig.require_channels(channels)
        self.channel_config = required
        return self

    def with_tools(self, ns_config: ToolNamespaceConfig) -> "SystemContent":
        """Register *ns_config* under its own name in ``tools``; return ``self``.

        An existing entry with the same name is overwritten.
        """
        if self.tools is None:
            self.tools = {ns_config.name: ns_config}
        else:
            self.tools[ns_config.name] = ns_config
        return self

    def with_browser_tool(self) -> "SystemContent":
        """Register the namespace returned by ``ToolNamespaceConfig.browser()``."""
        browser_namespace = ToolNamespaceConfig.browser()
        return self.with_tools(browser_namespace)

    def with_python_tool(self) -> "SystemContent":
        """Register the namespace returned by ``ToolNamespaceConfig.python()``."""
        python_namespace = ToolNamespaceConfig.python()
        return self.with_tools(python_namespace)

    def to_dict(self) -> dict:
        """Dump the non-``None`` fields and tag the result with ``"type": "system_content"``."""
        return {**self.model_dump(exclude_none=True), "type": "system_content"}

    @classmethod
    def from_dict(cls, raw: dict) -> "SystemContent":
        """Build an instance by passing the entries of *raw* as keyword arguments."""
        return cls(**raw)


class DeveloperContent(Content):
    # Content of a developer message: optional instructions plus optional
    # tool namespaces.  The ``with_*`` methods mutate the instance in place
    # and return it so that calls can be chained.
    instructions: Optional[str] = None
    tools: Optional[dict[str, ToolNamespaceConfig]] = None

    @classmethod
    def new(cls) -> "DeveloperContent":
        """Return an instance with no instructions and no tools."""
        return cls()

    def with_instructions(self, instructions: str) -> "DeveloperContent":
        """Set ``instructions`` and return ``self``."""
        self.instructions = instructions
        return self

    def with_tools(self, ns_config: ToolNamespaceConfig) -> "DeveloperContent":
        """Register *ns_config* under its own name in ``tools``; return ``self``.

        An existing entry with the same name is overwritten.
        """
        if self.tools is None:
            self.tools = {ns_config.name: ns_config}
        else:
            self.tools[ns_config.name] = ns_config
        return self

    def with_function_tools(
        self, tools: Sequence[ToolDescription]
    ) -> "DeveloperContent":
        """Register *tools* as a namespace named ``"functions"``; return ``self``."""
        functions_namespace = ToolNamespaceConfig(
            name="functions", description=None, tools=list(tools)
        )
        return self.with_tools(functions_namespace)

    def to_dict(self) -> dict:
        """Dump the non-``None`` fields and tag the result with ``"type": "developer_content"``."""
        return {**self.model_dump(exclude_none=True), "type": "developer_content"}

    @classmethod
    def from_dict(cls, raw: dict) -> "DeveloperContent":
        """Build an instance by passing the entries of *raw* as keyword arguments."""
        return cls(**raw)


# Message & Conversation -----------------------------------------------------


class Message(BaseModel):
    # One chat message: an author, a list of content items and the optional
    # ``channel`` / ``recipient`` / ``content_type`` header values.
    author: Author
    content: List[Content] = Field(default_factory=list)
    channel: Optional[str] = None
    recipient: Optional[str] = None
    content_type: Optional[str] = None

    # ------------------------------------------------------------------
    # Convenience constructors (mirroring the Rust API)
    # ------------------------------------------------------------------

    @classmethod
    def from_author_and_content(
        cls, author: Author, content: Union[str, Content]
    ) -> "Message":
        """Build a single-content message; a ``str`` is wrapped in ``TextContent``."""
        item = TextContent(text=content) if isinstance(content, str) else content
        return cls(author=author, content=[item])

    @classmethod
    def from_role_and_content(
        cls, role: Role, content: Union[str, Content]
    ) -> "Message":  # noqa: D401 – parity with Rust API
        """Like ``from_author_and_content`` with an unnamed author of *role*."""
        unnamed_author = Author(role=role)
        return cls.from_author_and_content(unnamed_author, content)

    @classmethod
    def from_role_and_contents(
        cls, role: Role, contents: Sequence[Content]
    ) -> "Message":
        """Build a message from an unnamed author of *role* and several content items."""
        unnamed_author = Author(role=role)
        return cls(author=unnamed_author, content=list(contents))

    # ------------------------------------------------------------------
    # Builder helpers
    # ------------------------------------------------------------------

    def adding_content(self, content: Union[str, Content]) -> "Message":
        """Append *content* (a ``str`` is wrapped in ``TextContent``); return ``self``."""
        item = TextContent(text=content) if isinstance(content, str) else content
        self.content.append(item)
        return self

    def with_channel(self, channel: str) -> "Message":
        """Set ``channel`` and return ``self``."""
        self.channel = channel
        return self

    def with_recipient(self, recipient: str) -> "Message":
        """Set ``recipient`` and return ``self``."""
        self.recipient = recipient
        return self

    def with_content_type(self, content_type: str) -> "Message":
        """Set ``content_type`` and return ``self``."""
        self.content_type = content_type
        return self

    # ------------------------------------------------------------------
    # Serialisation helpers
    # ------------------------------------------------------------------

    def to_dict(self) -> Dict[str, Any]:  # noqa: D401 – simple mapper
        """Return the message as a plain dict.

        The author's fields are placed at the top level next to ``"content"``;
        the three optional header values are only included when not ``None``.
        """
        payload: Dict[str, Any] = dict(self.author.model_dump())
        payload["content"] = [item.to_dict() for item in self.content]
        for key in ("channel", "recipient", "content_type"):
            value = getattr(self, key)
            if value is not None:
                payload[key] = value
        return payload

    def to_json(self) -> str:  # noqa: D401
        """Return ``to_dict()`` encoded as a JSON string."""
        as_dict = self.to_dict()
        return json.dumps(as_dict)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Message":
        """Rebuild a message from the dict layout produced by ``to_dict``.

        Raises:
            ValueError: If a content entry has an unrecognised ``"type"``
                (or the role is unknown).
        """
        # Simple, sufficient implementation for test-roundtrip purposes.
        author = Author(role=Role(data["role"]), name=data.get("name"))

        raw_content = data["content"]

        # The Rust side serialises *single* text contents as a **plain string**
        # for convenience.  Detect this shortcut and normalise it to the list
        # representation that the rest of the Python code expects.
        if isinstance(raw_content, str):
            raw_content = [{"type": "text", "text": raw_content}]

        # Tag → model pairs, checked in this order for every content entry.
        variants = (
            ("text", TextContent),
            ("system_content", SystemContent),
            ("developer_content", DeveloperContent),
        )

        contents: List[Content] = []
        for raw in raw_content:
            kind = raw.get("type")
            model = next(
                (candidate for tag, candidate in variants if kind == tag), None
            )
            if model is None:  # pragma: no cover – unknown variant
                raise ValueError(f"Unknown content variant: {raw}")
            contents.append(model(**raw))

        msg = cls(author=author, content=contents)
        # Assigned after construction, exactly as before, rather than passed
        # to the constructor.
        msg.channel = data.get("channel")
        msg.recipient = data.get("recipient")
        msg.content_type = data.get("content_type")
        return msg


class Conversation(BaseModel):
    # An ordered list of messages.
    messages: List[Message] = Field(default_factory=list)

    @classmethod
    def from_messages(cls, messages: Sequence[Message]) -> "Conversation":  # noqa: D401
        """Build a conversation from a copy of *messages*."""
        copied = list(messages)
        return cls(messages=copied)

    def __iter__(self):
        """Iterate over the messages of the conversation."""
        return iter(self.messages)

    # Serialisation helpers -------------------------------------------------

    def to_dict(self) -> Dict[str, Any]:  # noqa: D401
        """Return ``{"messages": [...]}`` with every message converted by ``to_dict``."""
        serialised = [message.to_dict() for message in self.messages]
        return {"messages": serialised}

    def to_json(self) -> str:  # noqa: D401
        """Return ``to_dict()`` encoded as a JSON string."""
        as_dict = self.to_dict()
        return json.dumps(as_dict)

    @classmethod
    def from_json(cls, payload: str) -> "Conversation":  # noqa: D401
        """Parse a JSON string holding a ``"messages"`` list into a conversation."""
        parsed = json.loads(payload)
        restored = [Message.from_dict(entry) for entry in parsed["messages"]]
        return cls(messages=restored)


# ---------------------------------------------------------------------------
# Encoding interaction (thin wrappers around the Rust bindings)
# ---------------------------------------------------------------------------


class RenderConversationConfig(BaseModel):
    # Options accepted by the conversation-rendering methods of
    # ``HarmonyEncoding``; the flag is forwarded to the bindings as
    # ``{"auto_drop_analysis": ...}``.
    auto_drop_analysis: bool = True


class RenderOptions(BaseModel):
    # Options accepted by ``HarmonyEncoding.render``; the flag is forwarded
    # to the bindings as ``{"conversation_has_function_tools": ...}``.
    conversation_has_function_tools: bool = False


class HarmonyEncoding:
    """High-level wrapper around the Rust ``PyHarmonyEncoding`` class.

    Every public method delegates to the wrapped binding object stored in
    ``_inner``.  This class only converts between Python-side objects and the
    plain strings / dicts / lists that are passed to the bindings.
    """

    def __init__(self, inner: _PyHarmonyEncoding):
        self._inner = inner

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _conversation_config_dict(
        config: Optional[RenderConversationConfig],
    ) -> Dict[str, bool]:
        """Convert *config* into the plain dict handed to the bindings.

        ``None`` is treated as ``auto_drop_analysis=True``.
        """
        auto_drop_analysis = True if config is None else config.auto_drop_analysis
        return {"auto_drop_analysis": auto_drop_analysis}

    @staticmethod
    def _render_options_dict(
        render_options: Optional[RenderOptions],
    ) -> Dict[str, bool]:
        """Convert *render_options* into the plain dict handed to the bindings.

        ``None`` is treated as ``conversation_has_function_tools=False``.
        """
        has_function_tools = (
            False
            if render_options is None
            else render_options.conversation_has_function_tools
        )
        return {"conversation_has_function_tools": has_function_tools}

    # ------------------------------------------------------------------
    # Delegated helpers
    # ------------------------------------------------------------------

    @property
    def name(self) -> str:  # noqa: D401
        """Name of the wrapped encoding, as reported by the bindings."""
        return self._inner.name  # type: ignore[attr-defined]

    @functools.cached_property
    def special_tokens_set(self) -> set[str]:
        """Set of special-token strings (computed once per instance, then cached)."""
        return set(self._inner.special_tokens())

    # -- Rendering -----------------------------------------------------

    def render_conversation_for_completion(
        self,
        conversation: Conversation,
        next_turn_role: Role,
        config: Optional[RenderConversationConfig] = None,
    ) -> List[int]:
        """
        Render a conversation for completion.
        Args:
            conversation: Conversation object
            next_turn_role: Role for the next turn
            config: Optional RenderConversationConfig (default auto_drop_analysis=True)
        Returns:
            The value returned by the bindings (annotated as a list of token ids).
        """
        return self._inner.render_conversation_for_completion(
            conversation_json=conversation.to_json(),
            next_turn_role=str(next_turn_role.value),
            config=self._conversation_config_dict(config),
        )

    def render_conversation(
        self,
        conversation: Conversation,
        config: Optional[RenderConversationConfig] = None,
    ) -> List[int]:
        """Render a conversation without appending a new role.

        Args:
            conversation: Conversation object
            config: Optional RenderConversationConfig (default auto_drop_analysis=True)
        """
        return self._inner.render_conversation(
            conversation_json=conversation.to_json(),
            config=self._conversation_config_dict(config),
        )

    def render_conversation_for_training(
        self,
        conversation: Conversation,
        config: Optional[RenderConversationConfig] = None,
    ) -> List[int]:
        """Render a conversation for training.

        Args:
            conversation: Conversation object
            config: Optional RenderConversationConfig (default auto_drop_analysis=True)
        """
        return self._inner.render_conversation_for_training(
            conversation_json=conversation.to_json(),
            config=self._conversation_config_dict(config),
        )

    def render(
        self, message: Message, render_options: Optional[RenderOptions] = None
    ) -> List[int]:
        """Render a single message into tokens.

        Args:
            message: Message object
            render_options: Optional RenderOptions
                (default conversation_has_function_tools=False)
        """
        return self._inner.render(
            message_json=message.to_json(),
            render_options=self._render_options_dict(render_options),
        )

    # -- Parsing -------------------------------------------------------

    def parse_messages_from_completion_tokens(
        self,
        tokens: Sequence[int],
        role: Optional[Role] = None,
        *,
        strict: bool = True,
    ) -> List[Message]:
        """Parse completion tokens into messages.

        The bindings return a JSON string holding a list of message dicts;
        each entry is converted with ``Message.from_dict``.

        Args:
            tokens: Token ids to parse.
            role: Optional role forwarded to the bindings as its string value.
            strict: Forwarded unchanged to the bindings.
        """
        role_str = None if role is None else str(role.value)
        raw_json: str = self._inner.parse_messages_from_completion_tokens(
            list(tokens), role_str, strict
        )
        return [Message.from_dict(m) for m in json.loads(raw_json)]

    # -- Token decoding ------------------------------------------------

    def decode_utf8(self, tokens: Sequence[int]) -> str:
        """Decode a list of tokens into a UTF-8 string.

        Will raise an error if the tokens result in invalid UTF-8. Use
        ``decode`` if you want to replace invalid UTF-8 with the unicode
        replacement character.
        """
        return self._inner.decode_utf8(list(tokens))

    def encode(
        self,
        text: str,
        *,
        allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
        disallowed_special: Literal["all"] | Collection[str] = "all",
    ) -> list[int]:
        """Encodes a string into tokens.

        Special tokens are artificial tokens used to unlock capabilities from a model,
        such as fill-in-the-middle. So we want to be careful about accidentally encoding special
        tokens, since they can be used to trick a model into doing something we don't want it to do.

        Hence, by default, encode will raise an error if it encounters text that corresponds
        to a special token. This can be controlled on a per-token level using the `allowed_special`
        and `disallowed_special` parameters. In particular:
        - Setting `disallowed_special` to () will prevent this function from raising errors and
          cause all text corresponding to special tokens to be encoded as natural text.
        - Setting `allowed_special` to "all" will cause this function to treat all text
          corresponding to special tokens to be encoded as special tokens.

        The token ids in the examples below are illustrative; actual ids depend
        on the loaded encoding.

        ```
        >>> enc.encode("hello world")
        [31373, 995]
        >>> enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})
        [50256]
        >>> enc.encode("<|endoftext|>", allowed_special="all")
        [50256]
        >>> enc.encode("<|endoftext|>")
        # Raises HarmonyError
        >>> enc.encode("<|endoftext|>", disallowed_special=())
        [27, 91, 437, 1659, 5239, 91, 29]
        ```

        Raises:
            HarmonyError: If *text* contains a disallowed special token.
        """
        if allowed_special == "all":
            allowed_special = self.special_tokens_set
        if disallowed_special == "all":
            disallowed_special = self.special_tokens_set - set(allowed_special)
        if disallowed_special:
            # The regex helper expects a ``frozenset``.
            if not isinstance(disallowed_special, frozenset):
                disallowed_special = frozenset(disallowed_special)
            if match := _special_token_regex(disallowed_special).search(text):
                raise_disallowed_special_token(match.group())

        allowed = list(allowed_special)
        try:
            return self._inner.encode(text, allowed)
        except UnicodeEncodeError:
            # The text could not be encoded as-is (presumably because of lone
            # surrogates).  Round-trip it through UTF-16, replacing whatever
            # cannot be decoded, and try once more.
            text = text.encode("utf-16", "surrogatepass").decode("utf-16", "replace")
            return self._inner.encode(text, allowed)

    def decode(self, tokens: Sequence[int], errors: str = "replace") -> str:
        """Decodes a list of tokens into a string.

        WARNING: the default behaviour of this function is lossy, since decoded bytes are not
        guaranteed to be valid UTF-8. You can use `decode_utf8` if you want to raise an error on invalid UTF-8.

        Args:
            tokens: Token ids to decode.
            errors: Error handler passed to ``bytes.decode`` (default ``"replace"``).

        ```
        >>> enc.decode([31373, 995])
        'hello world'
        ```
        """
        data = bytes(self._inner.decode_bytes(list(tokens)))
        return data.decode("utf-8", errors=errors)

    def is_special_token(self, token: int) -> bool:
        """Returns if an individual token is a special token"""
        return self._inner.is_special_token(token)

    # -- Stop tokens --------------------------------------------------

    def stop_tokens(self) -> List[int]:
        """Return the stop tokens reported by the bindings."""
        return self._inner.stop_tokens()

    def stop_tokens_for_assistant_actions(self) -> List[int]:
        """Return the assistant-action stop tokens reported by the bindings."""
        return self._inner.stop_tokens_for_assistant_actions()


class StreamState(Enum):
    """Parser states; ``StreamableParser.state`` builds one from the ``"state"`` entry of the parser's state data."""

    EXPECT_START = "ExpectStart"
    HEADER = "Header"
    CONTENT = "Content"


class StreamableParser:
    """Incremental parser over completion tokens.

    Wraps the binding-level parser stored in ``_inner``.  Tokens are fed one
    at a time with ``process``; the properties expose what the wrapped parser
    currently reports.
    """

    def __init__(
        self,
        encoding: HarmonyEncoding,
        role: Role | None,
        *,
        strict: bool = True,
    ) -> None:
        """Create a parser bound to *encoding*.

        Args:
            encoding: Encoding whose wrapped binding object is handed to the parser.
            role: Optional role, forwarded as its string value (``None`` stays ``None``).
            strict: Forwarded unchanged to the bindings.
        """
        role_str = None if role is None else str(role.value)
        self._inner = _PyStreamableParser(encoding._inner, role_str, strict)

    def process(self, token: int) -> "StreamableParser":
        """Feed one token to the parser and return ``self`` for chaining."""
        self._inner.process(token)
        return self

    def process_eos(self) -> "StreamableParser":
        """Forward an end-of-stream signal to the parser and return ``self``."""
        self._inner.process_eos()
        return self

    @property
    def current_content(self) -> str:
        """``current_content`` as reported by the wrapped parser."""
        return self._inner.current_content

    @property
    def current_role(self) -> Optional[Role]:
        """Current role converted to ``Role``, or ``None`` when the parser reports none."""
        reported = self._inner.current_role
        if reported is None:
            return None
        return Role(reported)

    @property
    def current_content_type(self) -> Optional[str]:
        """``current_content_type`` as reported by the wrapped parser."""
        return self._inner.current_content_type

    @property
    def last_content_delta(self) -> Optional[str]:
        """``last_content_delta`` as reported by the wrapped parser."""
        return self._inner.last_content_delta

    @property
    def messages(self) -> List[Message]:
        """Messages reported by the parser, decoded from JSON via ``Message.from_dict``."""
        decoded = json.loads(self._inner.messages)
        return [Message.from_dict(entry) for entry in decoded]

    @property
    def tokens(self) -> List[int]:
        """``tokens`` as reported by the wrapped parser."""
        return self._inner.tokens

    @property
    def state_data(self) -> Dict[str, Any]:
        """Return the parser's internal state, decoded from its JSON representation."""
        encoded_state = self._inner.state
        return json.loads(encoded_state)

    @property
    def state(self) -> StreamState:
        """Current state as a ``StreamState``, taken from ``state_data["state"]``."""
        return StreamState(self.state_data["state"])

    @property
    def current_recipient(self) -> Optional[str]:
        """``current_recipient`` as reported by the wrapped parser."""
        return self._inner.current_recipient

    @property
    def current_channel(self) -> Optional[str]:
        """``current_channel`` as reported by the wrapped parser."""
        return self._inner.current_channel


# Public helper --------------------------------------------------------------


def load_harmony_encoding(name: str | "HarmonyEncodingName") -> HarmonyEncoding:  # type: ignore[name-defined]
    """Load an encoding by *name* (delegates to the Rust implementation).

    Args:
        name: Encoding name.  Anything that is not a ``str`` instance is
            converted with ``str()`` before being handed to the bindings.

    Returns:
        A ``HarmonyEncoding`` wrapping the object returned by the bindings.
    """
    # Allow both strings and enum values.
    encoding_name = name if isinstance(name, str) else str(name)
    return HarmonyEncoding(_load_harmony_encoding(encoding_name))


# For *mypy* we expose a minimal stub of the `HarmonyEncodingName` enum.  At
# **runtime** the user is expected to pass the *string* names because the Rust
# side only operates on strings anyway.


class HarmonyEncodingName(str, Enum):  # noqa: D101 – simple enum stub
    """Known encoding names; members are also plain strings."""

    HARMONY_GPT_OSS = "HarmonyGptOss"

    def __str__(self) -> str:  # noqa: D401
        """Return the member's value as a plain string."""
        raw_name = self.value
        return str(raw_name)


# What should be re-exported when the user does ``from harmony import *``?
# NOTE(review): ``ReasoningEffort``, ``ChannelConfig``, ``ToolNamespaceConfig``,
# ``RenderConversationConfig`` and ``RenderOptions`` are defined in this module
# and appear in public method signatures, but are not listed below — confirm
# whether leaving them out of the star-import surface is intentional.
__all__ = [
    "Role",
    "Author",
    "Content",
    "TextContent",
    "DeveloperContent",
    "ToolDescription",
    "SystemContent",
    "Message",
    "Conversation",
    "HarmonyEncoding",
    "HarmonyEncodingName",
    "load_harmony_encoding",
    "StreamableParser",
    "StreamState",
    "HarmonyError",
]
