# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

from typing import Callable, Any, Iterable, Iterator, NamedTuple, Union, get_origin
import dataclasses
import itertools as it
from types import SimpleNamespace

from ..base_dsl.typing import as_numeric, Numeric, Constexpr
from ..base_dsl._mlir_helpers.arith import ArithValue
from ..base_dsl.common import DSLBaseError
from .._mlir import ir

# Runtime type of ``None``. Nothing in the visible part of this module reads
# it. NOTE(review): presumably kept for external importers — confirm before
# removing.
NoneType = type(None)

# =============================================================================
# Tree Utils
# =============================================================================


class DSLTreeFlattenError(DSLBaseError):
    """
    Exception raised when tree flattening fails due to unsupported types.

    Attributes:
        type_str: Type string supplied by the raiser. Every raise site in this
                  module passes the fully qualified class name of the object
                  that could not be flattened.
    """

    def __init__(self, msg: str, type_str: str) -> None:
        # The message goes to DSLBaseError; the type string is stored as a
        # separate attribute on the exception instance.
        super().__init__(msg)
        self.type_str = type_str


def unzip2(pairs: Iterable[tuple[Any, Any]]) -> tuple[list[Any], list[Any]]:
    """
    Split an iterable of two-element items into two parallel lists.

    Args:
        pairs: Iterable yielding items that unpack into exactly two values.
               It is traversed once, so one-shot iterators are fine.

    Returns:
        tuple: (firsts, seconds), each a list in the original iteration order

    Example:
        >>> unzip2([(1, "a"), (2, "b")])
        ([1, 2], ['a', 'b'])
    """
    # Unpack each item while materializing so a malformed item still raises
    # at the point it is encountered.
    materialized = [(first, second) for first, second in pairs]
    firsts = [first for first, _ in materialized]
    seconds = [second for _, second in materialized]
    return firsts, seconds


def get_fully_qualified_class_name(x: Any) -> str:
    """
    Return the dotted ``module.qualname`` of the class of ``x``.

    Args:
        x: Any object; its ``__class__`` is inspected

    Returns:
        str: ``__module__`` and ``__qualname__`` of the class joined by a dot

    Example:
        >>> get_fully_qualified_class_name([1, 2, 3])
        'builtins.list'
    """
    klass = x.__class__
    return f"{klass.__module__}.{klass.__qualname__}"


def is_frozen_dataclass(obj_or_cls: Any) -> bool:
    """
    Tell whether a class, or the class of an instance, is a frozen dataclass.

    Args:
        obj_or_cls: A class or an instance; for instances the check is made
                    on ``obj_or_cls.__class__``

    Returns:
        bool: True only when the class is a dataclass whose
              ``__dataclass_params__.frozen`` flag is set

    Example:
        >>> from dataclasses import dataclass
        >>> @dataclass(frozen=True)
        ... class Point:
        ...     x: int
        ...     y: int
        >>> is_frozen_dataclass(Point)
        True
        >>> is_frozen_dataclass(Point(1, 2))
        True
    """
    if isinstance(obj_or_cls, type):
        cls = obj_or_cls
    else:
        cls = obj_or_cls.__class__

    if not dataclasses.is_dataclass(cls):
        return False

    # The decorator records its options (including ``frozen``) on the class.
    params = getattr(cls, "__dataclass_params__", None)
    if params is None:
        return False
    return params.frozen


def is_dynamic_expression(x: Any) -> bool:
    """
    Tell whether an object exposes the DynamicExpression protocol.

    The protocol is detected purely by attribute presence: both
    `__extract_mlir_values__` and `__new_from_mlir_values__` must be
    reachable through ``hasattr``.

    Args:
        x: Any object to check

    Returns:
        bool: True when both protocol attributes are present, False otherwise
    """
    has_extract = hasattr(x, "__extract_mlir_values__")
    return has_extract and hasattr(x, "__new_from_mlir_values__")


def is_constexpr_field(field: dataclasses.Field) -> bool:
    """
    Tell whether a dataclass field is annotated as ``Constexpr``.

    Both the bare annotation ``Constexpr`` and a subscripted annotation whose
    ``typing.get_origin`` is ``Constexpr`` are recognised.

    Args:
        field: A ``dataclasses.Field`` whose ``type`` attribute is inspected

    Returns:
        bool: True when the annotation is ``Constexpr`` or a subscript of it
    """
    annotation = field.type
    return annotation is Constexpr or get_origin(annotation) is Constexpr


# =============================================================================
# PyTreeDef
# =============================================================================


class NodeType(NamedTuple):
    """
    Describes how one registered Python type is flattened and rebuilt.

    Attributes:
        name: String representation of the node type (``str(ty)`` when
              created through ``register_pytree_node``)
        to_iterable: Called as ``to_iterable(obj)``; returns a
                     ``(metadata, children)`` pair
        from_iterable: Called as ``from_iterable(metadata, children)``;
                       returns the reconstructed object
    """

    name: str
    to_iterable: Callable
    from_iterable: Callable


class PyTreeDef(NamedTuple):
    """
    Represents the structure definition of an interior (container) pytree node.

    Attributes:
        node_type: The NodeType whose ``to_iterable`` produced this node
        node_metadata: SimpleNamespace metadata returned by ``to_iterable``
                       and handed back to ``from_iterable`` on reconstruction
        child_treedefs: Tuple of child tree definitions, in child order
    """

    node_type: NodeType
    node_metadata: SimpleNamespace
    # As built by _tree_flatten, entries are either PyTreeDef or Leaf objects
    # even though only PyTreeDef appears in the annotation.
    child_treedefs: tuple["PyTreeDef", ...]


@dataclasses.dataclass(frozen=True)
class Leaf:
    """
    Represents a leaf node in a pytree structure.

    Attributes:
        is_numeric: When set, unflattening passes the flat value through
                    ``as_numeric``
        is_none: When set, the leaf stands for ``None``; it contributes no
                 flat values and unflattening returns ``None``
        node_metadata: Optional SimpleNamespace. If its
                       ``is_dynamic_expression`` attribute is truthy,
                       unflattening rebuilds the value through
                       ``original_obj.__new_from_mlir_values__``
        ir_type_str: String form of the IR type, compared when two tree
                     definitions are checked for equality
    """

    is_numeric: bool = False
    is_none: bool = False
    # Both fields below default to None although their annotations do not
    # say so; readers must handle the None case.
    node_metadata: SimpleNamespace = None
    ir_type_str: str = None


# =============================================================================
# Default to_iterable and from_iterable
# =============================================================================


def extract_dataclass_members(x: Any) -> tuple[list[str], list[Any], list[str]]:
    """
    Split a dataclass instance's fields into flattenable members and constexpr fields.

    Fields whose annotation is ``Constexpr`` (see ``is_constexpr_field``) are
    reported by name only; every other field contributes its name and its
    current value, in declaration order.

    Args:
        x: A dataclass instance

    Returns:
        tuple: (fields, members, constexpr_fields) where ``fields`` are the
               names of the non-constexpr fields, ``members`` their values in
               the same order, and ``constexpr_fields`` the names of the
               constexpr fields. A dataclass without fields yields three
               empty lists.

    Raises:
        DSLTreeFlattenError: If the instance carries an attribute that is not
            a declared dataclass field, or if a constexpr-annotated field
            holds a dynamic expression
    """
    declared = dataclasses.fields(x)
    declared_names = {field.name for field in declared}

    # Reject instance attributes that are not declared dataclass fields.
    # Instances without a ``__dict__`` (``__slots__`` classes) cannot carry
    # any, so there is nothing to check for them.
    for k in getattr(x, "__dict__", {}):
        if k not in declared_names:
            raise DSLTreeFlattenError(
                f"`{x}` has extra field `{k}`",
                type_str=get_fully_qualified_class_name(x),
            )

    fields: list[str] = []
    members: list[Any] = []
    constexpr_fields: list[str] = []
    for field in declared:
        v = getattr(x, field.name)
        if is_constexpr_field(field):
            # A constexpr field must hold a compile-time value, never a
            # dynamic expression.
            if is_dynamic_expression(v):
                raise DSLTreeFlattenError(
                    f"`{x}` has dynamic expression field `{field.name}` with a Constexpr type annotation `{field.type}`",
                    type_str=get_fully_qualified_class_name(x),
                )
            constexpr_fields.append(field.name)
        else:
            fields.append(field.name)
            members.append(v)

    # Always a 3-tuple: the caller unpacks exactly three values.
    return fields, members, constexpr_fields


def default_dataclass_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Produce the (metadata, children) pair used to flatten a dataclass instance.

    The field split is delegated to ``extract_dataclass_members``; the
    non-constexpr field values become the children.

    Args:
        x: A dataclass instance

    Returns:
        tuple: (metadata, members) where metadata carries ``type_str``,
               ``fields``, ``constexpr_fields`` and ``original_obj`` (the
               instance itself), and members is the list of non-constexpr
               field values
    """
    field_names, member_values, constexpr_names = extract_dataclass_members(x)
    qualified_name = get_fully_qualified_class_name(x)

    node_metadata = SimpleNamespace(
        type_str=qualified_name,
        fields=field_names,
        constexpr_fields=constexpr_names,
        original_obj=x,
    )
    return node_metadata, member_values


def set_dataclass_attributes(
    instance: Any,
    fields: list[str],
    values: Iterable[Any],
    constexpr_fields: list[str],
) -> Any:
    """
    Build a copy of a dataclass instance with new values for some fields.

    Args:
        instance: The dataclass instance used as the template
        fields: Names of the fields to overwrite, paired positionally with
                ``values``
        values: Iterable of new field values
        constexpr_fields: Names of fields whose current value is carried over
                          unchanged from ``instance``

    Returns:
        ``instance`` itself when ``fields`` is empty; otherwise a new instance
        created with ``dataclasses.replace`` (so frozen dataclasses are
        supported), with any ``init=False`` fields assigned afterwards
    """
    if not fields:
        return instance

    updates = dict(zip(fields, values))
    for name in constexpr_fields:
        updates[name] = getattr(instance, name)

    # dataclasses.replace() raises ValueError when asked to set a field
    # declared with init=False, so those fields are held back and assigned
    # on the copy once it exists.
    non_init_names = {f.name for f in dataclasses.fields(instance) if not f.init}
    deferred = {k: v for k, v in updates.items() if k in non_init_names}
    init_updates = {k: v for k, v in updates.items() if k not in non_init_names}

    new_instance = dataclasses.replace(instance, **init_updates)
    for name, value in deferred.items():
        # object.__setattr__ bypasses the __setattr__ guard that frozen
        # dataclasses install.
        object.__setattr__(new_instance, name, value)
    return new_instance


def default_dataclass_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Rebuild a dataclass instance from its metadata and new child values.

    The instance stored in ``metadata.original_obj`` serves as the template;
    the actual update is delegated to ``set_dataclass_attributes``.

    Args:
        metadata: Metadata carrying ``original_obj``, ``fields`` and
                  ``constexpr_fields``
        children: Iterable of values for the names listed in ``metadata.fields``

    Returns:
        The reconstructed dataclass instance. As a side effect,
        ``metadata.original_obj`` is overwritten with that instance.
    """
    rebuilt = set_dataclass_attributes(
        metadata.original_obj,
        metadata.fields,
        children,
        metadata.constexpr_fields,
    )
    # The metadata now points at the rebuilt instance, so a later
    # reconstruction through the same metadata starts from it.
    metadata.original_obj = rebuilt
    return rebuilt


def dynamic_expression_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Produce the (metadata, children) pair for a dynamic expression.

    The children are whatever the object's `__extract_mlir_values__` method
    returns.

    Args:
        x: An object implementing `__extract_mlir_values__`

    Returns:
        tuple: (metadata, mlir_values) where metadata has
               ``is_dynamic_expression=1`` and ``original_obj=x``
    """
    node_metadata = SimpleNamespace(is_dynamic_expression=1, original_obj=x)
    mlir_values = x.__extract_mlir_values__()
    return node_metadata, mlir_values


def dynamic_expression_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Rebuild a dynamic expression from new MLIR values.

    The object stored in ``metadata.original_obj`` acts as the prototype: its
    `__new_from_mlir_values__` method is called with the children collected
    into a list.

    Args:
        metadata: Metadata containing the original object
        children: Iterable of MLIR values to reconstruct from

    Returns:
        Whatever `__new_from_mlir_values__` returns
    """
    rebuild = metadata.original_obj.__new_from_mlir_values__
    mlir_values = list(children)
    return rebuild(mlir_values)


def default_dict_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Produce the (metadata, children) pair for a dict or SimpleNamespace.

    Args:
        x: A dict-like object, or a SimpleNamespace (whose ``__dict__`` is
           used as the mapping)

    Returns:
        tuple: (metadata, values) where metadata carries ``type_str``,
               ``original_obj`` (``x`` itself) and ``fields`` (the keys), and
               values is the list of mapped values in the same order
    """
    mapping = x.__dict__ if isinstance(x, SimpleNamespace) else x
    keys = list(mapping.keys())
    values = list(mapping.values())

    node_metadata = SimpleNamespace(
        type_str=get_fully_qualified_class_name(x),
        original_obj=x,
        fields=keys,
    )
    return node_metadata, values


def default_dict_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Write new values back into the dict or SimpleNamespace held in metadata.

    The object in ``metadata.original_obj`` is updated in place and returned;
    no copy is made.

    Args:
        metadata: Metadata carrying ``original_obj`` and ``fields`` (the keys
                  or attribute names)
        children: Iterable of new values, paired positionally with
                  ``metadata.fields``

    Returns:
        The same object as ``metadata.original_obj``, after the update
    """
    target = metadata.original_obj
    pairs = zip(metadata.fields, children)

    if isinstance(target, SimpleNamespace):
        for attr_name, value in pairs:
            setattr(target, attr_name, value)
    else:
        for key, value in pairs:
            target[key] = value

    return target


# =============================================================================
# Register pytree nodes
# =============================================================================

# Registry mapping a Python type to the NodeType that flattens and rebuilds
# its instances. In this file it is written by register_pytree_node and looked
# up by exact ``type(x)`` in get_registered_node_types_or_insert.
_node_types: dict[type, NodeType] = {}


def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable) -> NodeType:
    """
    Register (or re-register) the flatten/unflatten handlers for a type.

    An existing registration for the same type is replaced.

    Args:
        ty: The type to register
        to_iter: Callable turning an instance into a (metadata, children) pair
        from_iter: Callable rebuilding an instance from (metadata, children)

    Returns:
        NodeType: The NodeType stored in the registry, named ``str(ty)``
    """
    node_type = NodeType(name=str(ty), to_iterable=to_iter, from_iterable=from_iter)
    _node_types[ty] = node_type
    return node_type


def register_default_node_types() -> None:
    """
    Register the built-in container types: tuple, list, dict and SimpleNamespace.

    Sequences record their length as metadata and are rebuilt by passing the
    children to the sequence constructor. dict and SimpleNamespace share the
    ``default_dict_*`` handlers.
    """
    register_pytree_node(
        tuple,
        lambda seq: (SimpleNamespace(length=len(seq)), list(seq)),
        lambda _meta, items: tuple(items),
    )
    register_pytree_node(
        list,
        lambda seq: (SimpleNamespace(length=len(seq)), list(seq)),
        lambda _meta, items: list(items),
    )
    for mapping_type in (dict, SimpleNamespace):
        register_pytree_node(
            mapping_type, default_dict_to_iterable, default_dict_from_iterable
        )


# Initialize default registrations at import time, so tuple, list, dict and
# SimpleNamespace are in the registry before any flatten call.
register_default_node_types()


# =============================================================================
# tree_flatten and tree_unflatten
# =============================================================================

"""
Behavior of tree_flatten and tree_unflatten, for example:

```python
    a = (1, 2, 3)
    b = MyClass(a=1, b =[1,2,3])
```

yields the following tree:

```python
    tree_a = PyTreeDef(type = 'tuple',
                       metadata = {length = 3},
                       children = [
                           Leaf(type = int),
                           Leaf(type = int),
                           Leaf(type = int),
                       ],
                       )
    flattened_a = [1, 2, 3]
    tree_b = PyTreeDef(type = 'MyClass',
                       metadata = {fields = ['a','b']},
                       children = [
                           Leaf(type=int),
                           PyTreeDef(type = `list`,
                                     metadata = {length = 3},
                                     children = [
                                          Leaf(type=`int`),
                                          Leaf(type=`int`),
                                          Leaf(type=`int`),
                                     ],
                           ),
                       ],
                       )
    flattened_b = [1, 1, 2, 3]
```

Passing the flattened values and the PyTreeDef to tree_unflatten reconstructs the original structure:

``` python
    unflattened_a = tree_unflatten(tree_a, flattened_a)
    unflattened_b = tree_unflatten(tree_b, flattened_b)
```

yields the following structure:

``` python
    unflattened_a = (1, 2, 3)
    unflattened_b = MyClass(a=1, b =[1,2,3])
```

unflattened_a should be structurally identical to a, and unflattened_b should be structurally identical to b.

"""


def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]:
    """
    Flatten a nested structure into a list of leaf values plus its structure.

    The traversal itself is done by ``_tree_flatten``; this wrapper only
    materializes the (possibly lazy) leaf iterable into a list.

    Args:
        x: The nested structure to flatten

    Returns:
        tuple: (flat_values, treedef) where flat_values is a list of leaf
               values in traversal order and treedef describes the structure
               (a ``Leaf`` rather than a ``PyTreeDef`` when ``x`` itself is a
               leaf)

    Raises:
        DSLTreeFlattenError: If the structure contains unsupported types

    Example (illustrative; plain Python numbers are converted to IR values
    during flattening, so the real list holds IR values):
        >>> tree_flatten([1, [2, 3], 4])
        ([1, 2, 3, 4], PyTreeDef(...))
    """
    leaves, structure = _tree_flatten(x)
    return list(leaves), structure


def get_registered_node_types_or_insert(x: Any) -> Union[NodeType, None]:
    """
    Get the registered node type for an object, registering it if necessary.

    The registry is looked up by the exact ``type(x)``. On a miss the type is
    registered automatically, in this priority order:
    - objects implementing the DynamicExpression protocol get the dynamic
      expression handlers
    - dataclass *instances* get the default dataclass handlers

    Args:
        x: The object to get or register a node type for

    Returns:
        NodeType or None: The registered node type, or None if the type
                         cannot be registered
    """
    ty = type(x)
    registered = _node_types.get(ty)
    if registered is not None:
        return registered

    # DynamicExpression takes precedence: a dataclass that also implements
    # the protocol is handled through the protocol.
    if is_dynamic_expression(x):
        return register_pytree_node(
            ty, dynamic_expression_to_iterable, dynamic_expression_from_iterable
        )

    # dataclasses.is_dataclass() is also True for dataclass classes. For a
    # class object ``type(x)`` is its metaclass (usually ``type``), and
    # registering that would make every class object look like a dataclass
    # node, so only instances are accepted here.
    if dataclasses.is_dataclass(x) and not isinstance(x, type):
        return register_pytree_node(
            ty, default_dataclass_to_iterable, default_dataclass_from_iterable
        )

    return None


def create_leaf_for_value(
    x: Any,
    is_numeric: bool = False,
    is_none: bool = False,
    node_metadata: SimpleNamespace = None,
    ir_type_str: str = None,
) -> Leaf:
    """
    Build the Leaf describing a single value.

    Args:
        x: The value the leaf stands for; only its ``type`` attribute (if
           any) is read
        is_numeric: Whether this is a numeric value
        is_none: Whether this represents None
        node_metadata: Optional metadata stored on the leaf unchanged
        ir_type_str: Optional IR type string; when falsy it falls back to
                     ``str(x.type)`` if ``x`` has a ``type`` attribute,
                     otherwise to None

    Returns:
        Leaf: The created leaf node
    """
    resolved_type_str = ir_type_str
    if not resolved_type_str:
        resolved_type_str = str(x.type) if hasattr(x, "type") else None

    return Leaf(
        is_numeric=is_numeric,
        is_none=is_none,
        node_metadata=node_metadata,
        ir_type_str=resolved_type_str,
    )


def _tree_flatten(x: Any) -> tuple[Iterable[Any], Union[PyTreeDef, Leaf]]:
    """
    Internal function to flatten a tree structure.

    Values are classified in this order, and the first match wins:
    ``None``; ``ArithValue`` implementing the DynamicExpression protocol;
    plain ``ArithValue``; ``ir.Value``; ``Numeric``; registered (or
    auto-registrable) pytree node types; and finally anything ``as_numeric``
    can convert.

    Args:
        x: The object to flatten

    Returns:
        tuple: (flattened_values, treedef) where flattened_values is an
               iterable of leaf values (a lazy chain for container nodes) and
               treedef is a ``Leaf`` or ``PyTreeDef``

    Raises:
        DSLTreeFlattenError: If a registered ``to_iterable`` returns ``None``
            children, or if the object is of no supported kind and cannot be
            converted with ``as_numeric``; in the latter case the conversion
            error is attached as ``__cause__``
    """
    if x is None:
        return [], create_leaf_for_value(x, is_none=True)

    # The isinstance checks below are order-sensitive: the DynamicExpression
    # form of ArithValue must be tested before the plain ArithValue branch,
    # which in turn precedes the generic ir.Value branch.
    if isinstance(x, ArithValue) and is_dynamic_expression(x):
        v = x.__extract_mlir_values__()
        # NOTE(review): only v[0] is inspected and _tree_unflatten consumes a
        # single value for such leaves — assumes exactly one MLIR value here.
        return v, create_leaf_for_value(
            x,
            node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x),
            ir_type_str=str(v[0].type),
        )

    if isinstance(x, ArithValue):
        return [x], create_leaf_for_value(x, is_numeric=True)

    if isinstance(x, ir.Value):
        return [x], create_leaf_for_value(x)

    if isinstance(x, Numeric):
        v = x.__extract_mlir_values__()
        # Same single-value assumption as for dynamic ArithValue above.
        return v, create_leaf_for_value(
            x,
            node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x),
            ir_type_str=str(v[0].type),
        )

    node_type = get_registered_node_types_or_insert(x)
    if node_type:
        node_metadata, children = node_type.to_iterable(x)
        if children is None:
            # Flatten should not return None, it should return an empty list for real empty cases
            raise DSLTreeFlattenError(
                "Flatten Error: children is None", get_fully_qualified_class_name(x)
            )
        children_flat, child_trees = unzip2(map(_tree_flatten, children))
        flattened = it.chain.from_iterable(children_flat)
        return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees))

    # Last resort: treat the value as something convertible to a numeric.
    try:
        nval = as_numeric(x).ir_value()
        return [nval], create_leaf_for_value(nval, is_numeric=True)
    except Exception as err:
        # Chain the conversion failure so the root cause stays visible.
        raise DSLTreeFlattenError(
            "Flatten Error", get_fully_qualified_class_name(x)
        ) from err


def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any:
    """
    Rebuild a nested structure from flat values and a tree definition.

    This is the inverse of ``tree_flatten``: the values are consumed in order
    by ``_tree_unflatten`` while it walks ``treedef``.

    Args:
        treedef: The tree structure definition from tree_flatten
        xs: Flat values, in the order produced by flattening

    Returns:
        The reconstructed nested structure

    Example:
        >>> flat_values, treedef = tree_flatten([1, [2, 3], 4])
        >>> tree_unflatten(treedef, flat_values)
        [1, [2, 3], 4]
    """
    # One shared iterator: every leaf takes the next value from it.
    value_stream = iter(xs)
    return _tree_unflatten(treedef, value_stream)


def _tree_unflatten(treedef: Union[PyTreeDef, Leaf], xs: Iterator[Any]) -> Any:
    """
    Internal function to reconstruct a tree structure.

    Leaves take their value from the shared iterator ``xs``; container nodes
    rebuild all children first and then call the node type's
    ``from_iterable``.

    Args:
        treedef: The tree structure definition (``Leaf`` or ``PyTreeDef``)
        xs: Iterator of flat values, advanced as leaves are rebuilt

    Returns:
        The reconstructed object. For a ``treedef`` that is neither a
        ``Leaf`` nor a ``PyTreeDef`` nothing is consumed and None is returned.
    """
    if isinstance(treedef, Leaf):
        if treedef.is_none:
            # None leaves contribute no flat value.
            return None
        metadata = treedef.node_metadata
        if metadata and getattr(metadata, "is_dynamic_expression", False):
            # Exactly one flat value is consumed per dynamic-expression leaf.
            return metadata.original_obj.__new_from_mlir_values__([next(xs)])
        if treedef.is_numeric:
            return as_numeric(next(xs))
        return next(xs)

    if isinstance(treedef, PyTreeDef):
        # Rebuild every child before calling from_iterable: with a lazy
        # generator, a from_iterable that does not consume all children would
        # leave their values in ``xs`` and misalign the remaining siblings.
        children = [_tree_unflatten(t, xs) for t in treedef.child_treedefs]
        return treedef.node_type.from_iterable(treedef.node_metadata, children)

    return None


def _check_tree_equal(lhs: Union[PyTreeDef, Leaf], rhs: Union[PyTreeDef, Leaf]) -> bool:
    """
    Recursively compare two tree definitions for structural equality.

    Two leaves match when their ``is_none`` flags and ``ir_type_str`` agree
    (``is_numeric`` and ``node_metadata`` are not compared). Two PyTreeDef
    nodes match when their node type, ``fields`` and ``constexpr_fields``
    metadata (each defaulting to an empty list when absent), child count and
    all corresponding children match. A leaf never matches a PyTreeDef.

    Args:
        lhs: Left tree definition (PyTreeDef or Leaf)
        rhs: Right tree definition (PyTreeDef or Leaf)

    Returns:
        bool: True if the trees are structurally equal, False otherwise
    """
    if isinstance(lhs, Leaf) and isinstance(rhs, Leaf):
        same_none = lhs.is_none == rhs.is_none
        return same_none and lhs.ir_type_str == rhs.ir_type_str

    if not (isinstance(lhs, PyTreeDef) and isinstance(rhs, PyTreeDef)):
        return False

    if lhs.node_type != rhs.node_type:
        return False

    for attr in ("fields", "constexpr_fields"):
        left_names = getattr(lhs.node_metadata, attr, [])
        right_names = getattr(rhs.node_metadata, attr, [])
        if left_names != right_names:
            return False

    if len(lhs.child_treedefs) != len(rhs.child_treedefs):
        return False

    return all(
        _check_tree_equal(left_child, right_child)
        for left_child, right_child in zip(lhs.child_treedefs, rhs.child_treedefs)
    )


def check_tree_equal(lhs: PyTreeDef, rhs: PyTreeDef) -> int:
    """
    Find the first direct child at which two tree definitions differ.

    Only the direct children of ``lhs`` and ``rhs`` are compared (each pair
    recursively, via ``_check_tree_equal``); the node type and metadata of the
    two roots themselves are not examined. Both roots must have the same
    number of children, which is enforced with an ``assert``.

    Args:
        lhs: Left tree definition
        rhs: Right tree definition

    Returns:
        int: Index of the first differing child, or -1 if all children match

    Example (illustrative; the comparison is structural, so differing leaf
    values alone do not count as a difference):
        >>> treedef1 = tree_flatten([1, [2, 3]])[1]
        >>> treedef2 = tree_flatten([1, (2, 3)])[1]
        >>> check_tree_equal(treedef1, treedef2)
        1  # The second child is a list on one side and a tuple on the other
    """
    assert len(lhs.child_treedefs) == len(rhs.child_treedefs)

    child_pairs = zip(lhs.child_treedefs, rhs.child_treedefs)
    for index, (left_child, right_child) in enumerate(child_pairs):
        if not _check_tree_equal(left_child, right_child):
            return index
    return -1
