# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

"""
This module provides helper functions that are generated by the preprocessor.
The preprocessor reads through Python's AST and transforms the input code.
"""

from typing import Callable, Iterator, Optional, overload
from typing_extensions import deprecated
import warnings
import inspect
from types import BuiltinFunctionType
from functools import lru_cache
from inspect import getmembers

from .utils.logger import log
from .common import *

from ._mlir_helpers.arith import ArithValue


class Executor:
    """
    Dispatches dynamic (IR-generating) execution of ``for`` loops, ``while``
    loops and ``if``/``elif``/``else`` statements to callbacks that the DSL
    registers through :meth:`set_functions`.

    Methods:
        set_functions:  Registers the callbacks used for dynamic-expression
                        detection, loop/while/if generation, comparison
                        chains, ``any``/``all`` and builtin redirection.

        for_execute: Generates MLIR for OP via the registered loop callback
        while_execute: Generates MLIR while OP via the registered while callback
        if_execute: Generates MLIR if OP via the registered if callback

    The ``write_args`` / ``write_args_names`` parameters of the ``*_execute``
    methods default to ``None`` and are replaced by a fresh empty list on every
    call. This replaces shared mutable ``[]`` defaults, so a callback that
    mutates the lists it receives cannot leak state into later calls.
    """

    def __init__(self):
        # Every callback stays unset until set_functions() is called; the
        # *_execute methods assert on their callback before using it.
        self._is_dynamic_expression = None
        self._loop_execute_range_dynamic = None
        self._if_dynamic = None
        self._while_dynamic = None
        self._compare_executor = None
        self._any_executor = None
        self._all_executor = None
        self._builtin_redirector = None

    def set_functions(
        self,
        *,
        is_dynamic_expression: Callable,
        loop_execute_range_dynamic: Callable,
        if_dynamic: Callable,
        while_dynamic: Callable,
        compare_executor: Callable,
        any_executor: Optional[Callable] = None,
        all_executor: Optional[Callable] = None,
        builtin_redirector: Optional[Callable] = None,
    ):
        """
        Register the callbacks used by this executor (keyword-only).

        :param is_dynamic_expression: predicate telling whether a value is a
            dynamic (IR) expression
        :param loop_execute_range_dynamic: callback used by :meth:`for_execute`
        :param if_dynamic: callback used by :meth:`if_execute`
        :param while_dynamic: callback used by :meth:`while_execute`
        :param compare_executor: callback for comparison chains
        :param any_executor: optional callback for ``any`` over dynamic values
        :param all_executor: optional callback for ``all`` over dynamic values
        :param builtin_redirector: optional callback mapping a built-in
            function to a DSL replacement
        """
        self._is_dynamic_expression = is_dynamic_expression
        self._loop_execute_range_dynamic = loop_execute_range_dynamic
        self._if_dynamic = if_dynamic
        self._while_dynamic = while_dynamic
        self._compare_executor = compare_executor
        self._any_executor = any_executor
        self._all_executor = all_executor
        self._builtin_redirector = builtin_redirector

    @staticmethod
    def convert_to_list(x):
        """This function is used to convert x to a list.
        If x is None, return an empty list.
        If x is not a list, return a list containing x.
        Otherwise, return x itself.
        """
        if x is None:
            return []
        if not isinstance(x, list):
            return [x]
        return x

    @staticmethod
    def converge_ret_val(res):
        """This function is used to converge res (the return value) of the function.
        If res is None, return None.
        If res is a list and has only one element, return the element.
        Otherwise, return res itself.
        """
        if res is None:
            return res
        elif isinstance(res, list) and len(res) == 1:
            return res[0]
        return res

    def for_execute(
        self,
        func,
        start,
        stop,
        step,
        write_args=None,
        full_write_args_count=0,
        write_args_names=None,
        unroll=-1,
        unroll_full=False,
        prefetch_stages=None,
    ):
        """
        Forward a preprocessed ``for`` loop to the registered loop callback.

        All arguments are passed through positionally and unchanged, except
        that ``None`` for ``write_args`` / ``write_args_names`` becomes ``[]``.

        :return: whatever the registered ``loop_execute_range_dynamic`` returns
        :raises AssertionError: if :meth:`set_functions` has not been called
        """
        assert self._loop_execute_range_dynamic, (
            "Functions must be set before execution."
        )
        write_args = [] if write_args is None else write_args
        write_args_names = [] if write_args_names is None else write_args_names
        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)

        return self._loop_execute_range_dynamic(
            func,
            start,
            stop,
            step,
            write_args,
            full_write_args_count,
            write_args_names,
            unroll,
            unroll_full,
            prefetch_stages,
        )

    def if_execute(
        self,
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        write_args=None,
        full_write_args_count=0,
        write_args_names=None,
    ):
        """
        Forward a preprocessed ``if`` statement to the registered if callback.

        :return: whatever the registered ``if_dynamic`` callback returns
        :raises AssertionError: if :meth:`set_functions` has not been called
        """
        assert self._if_dynamic, "Functions must be set before execution."
        write_args = [] if write_args is None else write_args
        write_args_names = [] if write_args_names is None else write_args_names

        # MLIR generation
        return self._if_dynamic(
            pred,
            then_block,
            else_block,
            write_args,
            full_write_args_count,
            write_args_names,
        )

    def while_execute(
        self,
        pred,
        while_before_block: Callable,
        while_after_block: Callable,
        write_args=None,
        full_write_args_count=0,
        write_args_names=None,
    ):
        """
        Forward a preprocessed ``while`` loop to the registered while callback.

        :return: whatever the registered ``while_dynamic`` callback returns
        :raises AssertionError: if :meth:`set_functions` has not been called
        """
        assert self._while_dynamic, "Functions must be set before execution."
        write_args = [] if write_args is None else write_args
        write_args_names = [] if write_args_names is None else write_args_names

        # NOTE(review): ``pred`` is accepted but not forwarded to the callback;
        # presumably the loop condition is produced by ``while_before_block``
        # -- confirm against the registered while_dynamic implementation.
        # MLIR generation
        return self._while_dynamic(
            while_before_block,
            while_after_block,
            write_args,
            full_write_args_count,
            write_args_names,
        )


# =============================================================================
# Decorator
# =============================================================================

# Module-wide Executor instance. The helper functions in this module read its
# callbacks, which remain None until set_functions() is called on it.
executor = Executor()


def loop_selector(
    start,
    stop,
    step,
    *,
    write_args=None,
    full_write_args_count=0,
    write_args_names=None,
    unroll=-1,
    unroll_full=False,
    prefetch_stages=None,
):
    """
    Build the decorator that the preprocessor applies to a dynamic loop body.

    ``Integer`` bounds are unwrapped to their IR values up front. The returned
    ``ir_loop(func)`` then forwards ``func`` and all loop parameters to
    ``executor.for_execute``.

    ``write_args`` / ``write_args_names`` default to ``None`` and are replaced
    by fresh empty lists on each call. This replaces shared mutable ``[]``
    defaults, because the lists are handed on to an externally registered
    callback.

    :return: a callable that takes the loop-body function and returns the
        result of ``executor.for_execute``
    """
    if write_args is None:
        write_args = []
    if write_args_names is None:
        write_args_names = []

    log().debug(
        "start [%s] stop [%s] step [%s] write_args [%s] full_write_args_count [%s] write_args_names [%s] unroll [%s] unroll_full [%s] prefetch_stages [%s]",
        start,
        stop,
        step,
        write_args,
        full_write_args_count,
        write_args_names,
        unroll,
        unroll_full,
        prefetch_stages,
    )
    from .typing import Integer

    def _maybe_upcast(value):
        # Unwrap DSL ``Integer`` wrappers to their underlying IR value.
        if isinstance(value, Integer):
            value = value.ir_value()

        return value

    start = _maybe_upcast(start)
    stop = _maybe_upcast(stop)
    step = _maybe_upcast(step)

    def ir_loop(func):
        return executor.for_execute(
            func,
            start,
            stop,
            step,
            write_args,
            full_write_args_count,
            write_args_names,
            unroll,
            unroll_full,
            prefetch_stages,
        )

    return ir_loop


def if_selector(pred, write_args=[]):
    """
    Build the callable the preprocessor uses to invoke an ``if`` body.

    A DSL ``Numeric`` predicate is unwrapped to its ``.value`` first. The
    returned callable calls ``block(pred, *write_args)``.
    """
    log().debug("pred [%s] write_args [%s]", pred, write_args)

    from .typing import Numeric

    # Hand the raw underlying value, not the Numeric wrapper, to the block.
    if isinstance(pred, Numeric):
        pred = pred.value

    def apply_block(block):
        return block(pred, *write_args)

    return apply_block


def while_selector(pred, write_args=[]):
    """
    Build the callable the preprocessor uses to invoke a ``while`` body.

    The returned callable calls ``block(pred, *write_args)`` and returns its
    result.
    """

    def run_block(block):
        arguments = (pred, *write_args)
        return block(*arguments)

    return run_block


def while_executor(
    pred,
    while_before_block: Callable,
    while_after_block: Callable,
    write_args=None,
    full_write_args_count=0,
    write_args_names=None,
):
    """
    Forward a preprocessed ``while`` loop to ``executor.while_execute``.

    ``write_args`` / ``write_args_names`` default to ``None`` and are replaced
    by fresh empty lists on each call. This replaces shared mutable ``[]``
    defaults, so a callback that mutates them cannot leak state into later
    calls.

    :return: whatever ``executor.while_execute`` returns
    """
    if write_args is None:
        write_args = []
    if write_args_names is None:
        write_args_names = []

    return executor.while_execute(
        pred,
        while_before_block,
        while_after_block,
        write_args,
        full_write_args_count,
        write_args_names,
    )


def if_executor(
    pred,
    then_block: Callable,
    else_block: Optional[Callable] = None,
    write_args=None,
    full_write_args_count=0,
    write_args_names=None,
):
    """
    Forward a preprocessed ``if`` statement to ``executor.if_execute``.

    ``write_args`` / ``write_args_names`` default to ``None`` and are replaced
    by fresh empty lists on each call. This replaces shared mutable ``[]``
    defaults, so a callback that mutates them cannot leak state into later
    calls.

    :return: whatever ``executor.if_execute`` returns
    """
    if write_args is None:
        write_args = []
    if write_args_names is None:
        write_args_names = []

    return executor.if_execute(
        pred,
        then_block,
        else_block,
        write_args,
        full_write_args_count,
        write_args_names,
    )


# =============================================================================
# Range
# =============================================================================


class range:
    """
    Range-like marker for dynamic loop iteration in the DSL.

    It mirrors the constructor forms of Python's built-in ``range`` (``stop``
    alone, or ``start, stop, step``) and adds loop-optimization keywords:

    - unroll: Number of iterations to unroll (0 or 1 = no unrolling)
    - unroll_full: Whether to fully unroll the loop
    - prefetch_stages: Number of prefetch stages to generate

    Uses of this class are expected to be rewritten by the preprocessor, so
    constructing or iterating it at runtime always raises ``DSLRuntimeError``.
    """

    _NOT_PREPROCESSED_MSG = "dynamic range should be always preprocessed to IR"

    @overload
    def __new__(cls, stop, unroll=0, unroll_full=False, prefetch_stages=None): ...

    @overload
    def __new__(
        cls, start, stop, step, unroll=0, unroll_full=False, prefetch_stages=None
    ): ...

    def __new__(cls, *args, **kwargs):
        raise DSLRuntimeError(cls._NOT_PREPROCESSED_MSG)

    def __iter__(self) -> Iterator[int]:
        raise DSLRuntimeError(self._NOT_PREPROCESSED_MSG)


@deprecated(
    "range_dynamic is deprecated and will be removed in the future, please remove it."
)
def range_dynamic(*_args, **_kwargs):
    """Deprecated placeholder that must be rewritten by the preprocessor.

    Calling it at runtime always raises ``DSLRuntimeError``.
    """
    message = "range_dynamic should be always preprocessed to IR"
    raise DSLRuntimeError(message)


def range_constexpr(*_args):
    """Compile-time range marker; the preprocessor must rewrite every call.

    Reaching this function at runtime always raises ``DSLRuntimeError``.
    """
    message = "range_constexpr should be preprocessed by preprocessor."
    raise DSLRuntimeError(message)


# =============================================================================
# If expressions
# =============================================================================


def const_expr(expression):
    """
    Require that ``expression`` is a compile-time (Python) value and return it.

    - A DSL ``Numeric`` that wraps a Python ``int``/``float``/``bool`` is
      unwrapped, and its Python value is returned.
    - Any other value that the executor does not report as dynamic is returned
      unchanged. It is *not* converted to ``bool``.
    - A ``Numeric`` wrapping anything else, or an expression the executor
      reports as dynamic, raises ``DSLRuntimeError``.
    """
    from .typing import Numeric

    if isinstance(expression, Numeric):
        inner = expression.value
        if isinstance(inner, (int, float, bool)):
            return inner
        is_dynamic = True
    else:
        is_dynamic = executor._is_dynamic_expression(expression)

    if not is_dynamic:
        return expression

    raise DSLRuntimeError(
        f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant).",
        context={
            "If your expression depends on dynamic values": "Remove `const_expr()`",
        },
    )


@deprecated(
    "dynamic_expr is deprecated and will be removed in the future, please remove it."
)
def dynamic_expr(expression):
    """Deprecated identity marker: hands ``expression`` back untouched."""
    result = expression
    return result


# =============================================================================
# Assertion & casting
# =============================================================================


def assert_executor(test, msg=None):
    """
    Execute a user ``assert`` that must be decidable at compile time.

    ``test`` may be a Python value, or a dynamic DSL ``Numeric`` that can be
    converted with ``.to(bool)``. Any other dynamic (IR) expression is
    rejected.

    :param test: the asserted condition
    :param msg: optional message attached to the ``AssertionError``
    :raises AssertionError: if ``test`` is falsy. This uses a Python ``assert``
        statement, so the check is skipped under ``python -O``.
    :raises DSLRuntimeError: if ``test`` is a dynamic expression that cannot
        be turned into a compile-time boolean
    """
    from .typing import Numeric

    fail = False
    conversion_error = None
    # Implicitly converting a dynamic expression to bool is not allowed, so
    # check for None explicitly instead of relying on truthiness.
    if test is not None and executor._is_dynamic_expression(test):
        if isinstance(test, Numeric):
            try:
                test = test.to(bool)
            except Exception as err:
                # Narrowed from a bare ``except:`` so that KeyboardInterrupt and
                # SystemExit still propagate; the cause is chained below.
                fail = True
                conversion_error = err
        else:
            fail = True

    if not fail:
        assert test, msg
    else:
        raise DSLRuntimeError(
            "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
            suggestion="Please replace with runtime assert.",
        ) from conversion_error


def bool_cast(value):
    """
    Convert a compile-time (Python) value to ``bool``.

    Raises ``DSLRuntimeError`` when the executor reports ``value`` as a dynamic
    (IR) expression, because such values cannot be truth-tested in Python.
    """
    if not executor._is_dynamic_expression(value):
        return bool(value)

    raise DSLRuntimeError(
        "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
        suggestion="Please explicitly convert to boolean with expressions like comparision.",
    )


def compare_executor(left, comparators, ops):
    """
    Evaluate a comparison chain through the registered comparison callback.

    Args:
        left: The leftmost operand of the chain
        comparators: The remaining operands, in order
        ops: The comparison operators to apply

    Returns:
        Whatever the registered ``compare_executor`` callback returns

    Raises:
        AssertionError: If set_functions() has not registered the callback
    """
    callback = executor._compare_executor
    assert callback is not None, "Function must be set before execution."
    return callback(left, comparators, ops)


def any_executor(iterable):
    """Evaluate ``any`` over ``iterable``, for both dynamic and static values.

    The registered ``any`` callback is used only when one is set and the
    executor reports ``iterable`` as dynamic; otherwise the built-in ``any``
    is applied.

    :param iterable: An iterable to check if any elements evaluate to True
    :type iterable: Iterable
    :return: boolean of Python value or IR value
    :rtype: bool or cutlass.Boolean
    """
    dynamic_any = executor._any_executor
    if not dynamic_any or not executor._is_dynamic_expression(iterable):
        return any(iterable)
    return dynamic_any(iterable)


def all_executor(iterable):
    """Evaluate ``all`` over ``iterable``, for both dynamic and static values.

    The registered ``all`` callback is used only when one is set and the
    executor reports ``iterable`` as dynamic; otherwise the built-in ``all``
    is applied.

    :param iterable: An iterable to check if all elements evaluate to True
    :type iterable: Iterable
    :return: boolean of Python value or IR value
    :rtype: bool or cutlass.Boolean
    """
    dynamic_all = executor._all_executor
    if not dynamic_all or not executor._is_dynamic_expression(iterable):
        return all(iterable)
    return dynamic_all(iterable)


# =============================================================================
# Control flow checks
# =============================================================================
class DSLOptimizationWarning(Warning):
    """
    Warning category for optimization-related issues in DSL code, such as a
    static loop that may be slow to compile.
    """

    def __init__(self, message):
        # Forward the message to Warning so that ``args`` is populated.
        # Exceptions are unpickled as ``cls(*args)``; with empty args that call
        # fails because ``message`` is a required parameter.
        super().__init__(message)
        self.message = message

    def __str__(self):
        return self.message


def range_value_check(*args):
    """
    Ensure all `range_constexpr` bounds are compile-time constants (Python ints).

    Accepts the same positional forms as Python's built-in range: ``(stop)``,
    ``(start, stop)`` or ``(start, stop, step)``. Every argument must support
    ``__index__``. A ``DSLOptimizationWarning`` is emitted when the loop would
    execute 64 or more iterations.

    :return: the normalized ``(start, stop, step)`` tuple of Python ints
    :raises DSLRuntimeError: if an argument is not a compile-time integer, the
        number of arguments is not 1-3, or ``step`` is zero
    """
    # Only the conversion is guarded. The warning below must not be swallowed
    # (e.g. under ``-W error``) and re-reported as a bogus "not constexpr" error.
    try:
        values = tuple(arg.__index__() for arg in args)
    except Exception as err:
        raise DSLRuntimeError(
            "`range_constexpr` requires constexpr (compile-time constant) for all arguments.",
            suggestion="Use `range` instead of `range_constexpr`.",
        ) from err

    if len(values) == 1:
        start, end, step = 0, values[0], 1
    elif len(values) == 2:
        start, end, step = values[0], values[1], 1
    elif len(values) == 3:
        start, end, step = values
    else:
        raise DSLRuntimeError(
            f"`range_constexpr` expected 1 to 3 arguments, got {len(values)}."
        )

    if step == 0:
        raise DSLRuntimeError("`range_constexpr` step argument must not be zero.")

    # Exact iteration count, matching the built-in range semantics. A range
    # whose direction disagrees with ``step`` is empty (length 0).
    if step > 0:
        range_length = max(0, (end - start + step - 1) // step)
    else:
        range_length = max(0, (start - end - step - 1) // -step)

    if range_length >= 64:
        warnings.warn(
            f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.",
            category=DSLOptimizationWarning,
            stacklevel=2,
        )

    return (start, end, step)


@lru_cache(maxsize=1)
def _get_self_module():
    """
    Return the module object that defines this helper (memoized after the
    first call).
    """
    this_function = _get_self_module
    return inspect.getmodule(this_function)


@lru_cache(maxsize=16)
def cf_symbol_check(symbol):
    """
    Check that ``symbol`` is the control-flow symbol provided by this module.

    - If ``symbol`` is a module, it must be this module or one of its parent
      packages. The name comparison respects the ``.`` boundary, so a module
      named ``cut`` does not match ``cutlass.<...>``.
    - Otherwise ``symbol`` must be defined in this module, as reported by
      ``inspect.getmodule``.

    :raises DSLRuntimeError: if ``symbol`` does not come from this module
    """
    self_module = _get_self_module()
    if inspect.ismodule(symbol):
        name = "range"
        own_name = self_module.__name__
        parent_name = symbol.__name__
        failed = not (
            own_name == parent_name or own_name.startswith(parent_name + ".")
        )
    else:
        name = symbol.__name__
        failed = inspect.getmodule(symbol) != self_module

    if failed:
        raise DSLRuntimeError(
            f"Incorrect {name} is used.",
            suggestion=f"Please avoid overriding `{name}` from DSL package.",
        )


def redirect_builtin_function(fcn):
    """
    Map a built-in function to its DSL replacement, when one is registered.

    Non-built-in callables are returned unchanged. Built-ins are returned
    unchanged too when no builtin redirector has been registered.
    """
    if not isinstance(fcn, BuiltinFunctionType):
        return fcn
    redirector = executor._builtin_redirector
    return redirector(fcn) if redirector else fcn


def copy_members(dest, src):
    """
    Overwrite ``dest``'s data members with the same-named values from ``src``.

    Members are enumerated from ``dest`` with ``inspect.getmembers``. A member
    is copied only when its name does not start with ``__``, its current value
    on ``dest`` is not callable, and ``src`` also has an attribute of that
    name. Copying an object onto itself is a no-op.
    """
    if dest is src:
        return

    for attr_name, current in getmembers(dest):
        skip = attr_name.startswith("__") or isinstance(current, Callable)
        if skip or not hasattr(src, attr_name):
            continue
        setattr(dest, attr_name, getattr(src, attr_name))


def get_locals_or_none(locals, symbols):
    """
    Look up each name of ``symbols`` in the ``locals`` mapping.

    Returns a list aligned with ``symbols``: the mapped value when the name is
    present, otherwise ``None``.
    """
    return [locals[name] if name in locals else None for name in symbols]


def closure_check(closures):
    """
    Reject functions that capture variables from an enclosing scope.

    :param closures: iterable of function objects to inspect
    :raises DSLRuntimeError: for the first function whose ``__closure__`` is
        non-empty, i.e. it captures variables. Captures are not supported in
        dynamic control flow.
    """
    for closure in closures:
        if closure.__closure__:
            raise DSLRuntimeError(
                f"Function `{closure.__name__}` is a closure that captures variables and is not supported in dynamic control flow",
                suggestion="Please explicitly pass in captured variables as arguments",
            )
