from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

if TYPE_CHECKING:
    import torch

import ray
from ray.dag import (
    ClassMethodNode,
    DAGNode,
)
from ray.dag.constants import COLLECTIVE_OPERATION_KEY, IS_CLASS_METHOD_OUTPUT_KEY
from ray.experimental.channel import ChannelContext
from ray.experimental.channel.torch_tensor_type import Communicator, TorchTensorType
from ray.experimental.util.types import (
    AllGatherOp,
    AllReduceOp,
    ReduceScatterOp,
    _CollectiveOp,
)
from ray.util.annotations import DeveloperAPI


class _CollectiveOperation:
    """
    Represent metadata for a communicator collective operation.

    Args:
        inputs: A list of lists of DAGNode. Each nested list inside
            of inputs should contain exactly one object per actor.
            If multiple nested lists are provided, then the order of
            actors should be the same for each nested list.
        op: The collective operation to perform.
        transport: The transport to use for the collective operation.
            When None, `TorchTensorType.ACCELERATOR` is used.

    Raises:
        ValueError: If `inputs` is empty, if any nested list is empty,
            contains a non-DAGNode, contains duplicate nodes, contains a
            node without an actor handle, or contains duplicate actor
            handles; if the nested lists differ in length, actor set, or
            actor order; or if a custom communicator's actor handles do
            not match the actor handles of the inputs.

    Requirements:
    1. Input nodes are unique.
    2. Actor handles are unique.
    3. Actor handles match the custom communicator group if specified.
    """

    def __init__(
        self,
        inputs: List[List[DAGNode]],
        op: _CollectiveOp,
        transport: Optional[Union[str, Communicator]] = None,
    ):
        self._actor_handles: List["ray.actor.ActorHandle"] = []

        # Without at least one nested list the validation loop never runs and
        # there would be no actor handles to record, so fail early with a
        # descriptive error.
        if not inputs:
            raise ValueError("Expected at least one list of input nodes.")

        # The list index is only mentioned in error messages when more than
        # one nested list was provided.
        mention_index = len(inputs) > 1

        for i, input_nodes in enumerate(inputs):
            list_suffix = f" at list at index {i}" if mention_index else ""

            # Check non-empty input list
            if len(input_nodes) == 0:
                nested_list_error_msg = f" at index {i}" if mention_index else ""
                raise ValueError(
                    f"Expected non-empty input list{nested_list_error_msg}."
                )

            # Check input nodes are DAGNode
            if not all(isinstance(node, DAGNode) for node in input_nodes):
                raise ValueError(
                    f"Expected all input nodes to be DAGNode{list_suffix}, "
                    f"but got {input_nodes}."
                )

            # Check unique input nodes
            if len(set(input_nodes)) != len(input_nodes):
                duplicates = [
                    input_node
                    for input_node in input_nodes
                    if input_nodes.count(input_node) > 1
                ]
                raise ValueError(
                    f"Expected unique input nodes{list_suffix}, but found duplicates: "
                    f"{duplicates}"
                )

            # Collect one actor handle per input node, in input order.
            current_actor_handles = []
            for input_node in input_nodes:
                actor_handle = input_node._get_actor_handle()
                if actor_handle is None:
                    raise ValueError(
                        f"Expected an actor handle from the input node{list_suffix}"
                    )
                current_actor_handles.append(actor_handle)

            # Check unique actor handles
            if len(set(current_actor_handles)) != len(current_actor_handles):
                # Report every node whose actor handle appears more than once.
                invalid_input_nodes = [
                    input_node
                    for input_node, handle in zip(input_nodes, current_actor_handles)
                    if current_actor_handles.count(handle) > 1
                ]
                raise ValueError(
                    f"Expected unique actor handles{list_suffix}, "
                    "but found duplicate actor handles from input nodes: "
                    f"{invalid_input_nodes}"
                )

            # The first nested list is the reference every later list is
            # compared against.
            if i == 0:
                first_actor_handles = current_actor_handles

            # Check all lists of DAGNode have the same number of nodes
            if len(inputs[0]) != len(input_nodes):
                raise ValueError(
                    f"Expected all input lists to have the same number of nodes. "
                    f"List at index 0 has length {len(inputs[0])}, but list at "
                    f"index {i} has length {len(input_nodes)}."
                )

            # Check all lists of DAGNode have same set of actor handles
            if set(first_actor_handles) != set(current_actor_handles):
                raise ValueError(
                    f"Expected all input lists to have the same set of actor handles. "
                    f"List at index 0 has actors {set(first_actor_handles)}, but list at "
                    f"index {i} has actors {set(current_actor_handles)}."
                )

            # Check all lists of DAGNode have same order of actor handles
            for j, (first, current) in enumerate(
                zip(first_actor_handles, current_actor_handles)
            ):
                if first != current:
                    raise ValueError(
                        f"Expected all input lists to have the same order of actor handles. "
                        f"List at index 0 has actor {first} at position {j}, but list at "
                        f"index {i} has actor {current} at position {j}."
                    )

        # Every nested list passed the same-order check above, so the handles
        # of the last list equal those of every other list.
        self._actor_handles = current_actor_handles

        self._op = op
        if transport is None:
            transport = TorchTensorType.ACCELERATOR
        self._type_hint = TorchTensorType(transport=transport, _direct_return=True)
        if isinstance(transport, Communicator):
            # A custom communicator must cover exactly the participating actors.
            if set(transport.get_actor_handles()) != set(self._actor_handles):
                raise ValueError(
                    "Expected actor handles to match the custom communicator group"
                )

    def __str__(self) -> str:
        """Return a readable summary of the actors, op, and type hint."""
        return (
            f"CollectiveOperation("
            f"_actor_handles={self._actor_handles}, "
            f"_op={self._op}, "
            f"_type_hint={self._type_hint})"
        )

    @property
    def actor_handles(self) -> List["ray.actor.ActorHandle"]:
        """The actor handles participating in this collective operation."""
        return self._actor_handles

    @property
    def type_hint(self) -> TorchTensorType:
        """The `TorchTensorType` built from the requested transport."""
        return self._type_hint

    def get_communicator(self) -> Communicator:
        """
        Return the communicator to use for this collective operation.

        A communicator registered in the current `ChannelContext` under the
        type hint's `communicator_id` takes precedence; otherwise the type
        hint's custom communicator is used.

        Raises:
            ValueError: If the type hint has neither a communicator ID nor a
                custom communicator.
        """
        communicator_id = self._type_hint.communicator_id
        if communicator_id is not None:
            ctx = ChannelContext.get_current()
            return ctx.communicators[communicator_id]

        custom_communicator = self._type_hint.get_custom_communicator()
        if custom_communicator is None:
            raise ValueError("Expected a communicator group")
        return custom_communicator

    def execute(
        self, *send_buf: "torch.Tensor"
    ) -> Union["torch.Tensor", Tuple["torch.Tensor", ...]]:
        """
        Call the collective operation on the input tensor(s). Output tensor(s) are
        allocated and returned.

        Args:
            *send_buf: A variable number of torch tensors to send to the collective
                operation. The tensors have the same order as the input nodes.

        Returns:
            A torch tensor or a tuple of torch tensors containing the results of the
            collective operation. The output tensors have the same length and order
            as the input node list of the actor of this operation.

        Raises:
            ValueError: If any input is not a torch tensor, if an all-reduce is
                given no tensors or tensors of differing dtypes, if the first
                dimension of a reduce-scatter input is not divisible by the
                world size, or if the operation type is not supported.
            AssertionError: If an all-gather or reduce-scatter is not given
                exactly one tensor.
        """
        import torch

        if not all(isinstance(t, torch.Tensor) for t in send_buf):
            raise ValueError("Expected a torch tensor for each input node")

        communicator = self.get_communicator()
        if isinstance(self._op, AllGatherOp):
            # NOTE: kept as an assert so that callers observing AssertionError
            # keep working; be aware that asserts are skipped under `python -O`.
            assert len(send_buf) == 1, "Expected exactly one tensor for allgather"
            t = send_buf[0]
            world_size = len(self._actor_handles)
            # The gathered result stacks one input per actor along dim 0.
            recv_buf = torch.empty(
                (t.shape[0] * world_size, *t.shape[1:]),
                dtype=t.dtype,
                device=t.device,
            )
            communicator.allgather(t, recv_buf)
        elif isinstance(self._op, AllReduceOp):
            if len(send_buf) == 0:
                # Fail with a clear message instead of an opaque torch error
                # from flattening an empty sequence of tensors.
                raise ValueError("Expected at least one input tensor for allreduce")
            if len(send_buf) == 1:
                t = send_buf[0]
                recv_buf = torch.empty_like(t)
                communicator.allreduce(t, recv_buf, self._op.reduceOp)
            else:
                if not all(t.dtype == send_buf[0].dtype for t in send_buf):
                    raise ValueError(
                        "Expected all input tensors to have the same dtype, "
                        f"but got {[t.dtype for t in send_buf]}"
                    )

                def unflatten_from(flat_buf, bufs):
                    """Slice `flat_buf` into views shaped like each of `bufs`."""
                    views = []
                    offset = 0
                    for buf in bufs:
                        numel = buf.numel()
                        views.append(
                            flat_buf[offset : offset + numel].view(buf.shape)
                        )
                        offset += numel
                    return tuple(views)

                # Flatten all inputs into one vector so a single allreduce
                # call covers every tensor; the reduction is done in place.
                flat_buf = torch.nn.utils.parameters_to_vector(send_buf)
                communicator.allreduce(flat_buf, flat_buf, self._op.reduceOp)
                recv_buf = unflatten_from(flat_buf, send_buf)
        elif isinstance(self._op, ReduceScatterOp):
            # NOTE: kept as an assert for the same reason as in the allgather
            # branch above.
            assert len(send_buf) == 1, "Expected exactly one tensor for reducescatter"
            t = send_buf[0]
            world_size = len(self._actor_handles)
            if t.shape[0] % world_size != 0:
                raise ValueError(
                    "Expected the first dimension of the input tensor to be divisible "
                    f"by the world size {world_size}"
                )
            # Each actor receives an equal slice of dim 0.
            recv_buf = torch.empty(
                (t.shape[0] // world_size, *t.shape[1:]),
                dtype=t.dtype,
                device=t.device,
            )
            communicator.reducescatter(t, recv_buf, self._op.reduceOp)
        else:
            # Previously this fell through to `return recv_buf` with the name
            # unbound; report the unsupported operation explicitly instead.
            raise ValueError(f"Unsupported collective operation: {self._op}")
        return recv_buf


@DeveloperAPI
class CollectiveOutputNode(ClassMethodNode):
    """Represent an output node from a communicator collective operation in a Ray DAG.

    The node keeps the input node(s) passed as `method_args` and the
    `_CollectiveOperation` found in `other_args_to_resolve` under
    `COLLECTIVE_OPERATION_KEY`.
    """

    def __init__(
        self,
        method_name: str,
        method_args: Tuple[
            DAGNode,
        ],
        method_kwargs: Dict[str, Any],
        method_options: Dict[str, Any],
        other_args_to_resolve: Dict[str, Any],
    ):
        """
        Args:
            method_name: Forwarded to `ClassMethodNode`.
            method_args: The input node(s) of this output node.
            method_kwargs: Forwarded to `ClassMethodNode`.
            method_options: Forwarded to `ClassMethodNode`.
            other_args_to_resolve: May carry the collective operation under
                `COLLECTIVE_OPERATION_KEY` and a flag under
                `IS_CLASS_METHOD_OUTPUT_KEY`.

        Raises:
            ValueError: If no collective operation is provided and the node
                is not flagged as a class method output.
        """
        # The positional arguments are the input node(s).
        self._inputs = method_args

        # The collective operation is optional in the mapping; absent -> None.
        collective_op = other_args_to_resolve.get(COLLECTIVE_OPERATION_KEY)
        self._collective_op: _CollectiveOperation = collective_op

        # The class-method-output flag defaults to False when absent.
        is_class_method_output = other_args_to_resolve.get(
            IS_CLASS_METHOD_OUTPUT_KEY, False
        )
        self._is_class_method_output: bool = is_class_method_output

        # A missing collective operation is only tolerated for nodes flagged
        # as class method outputs.
        if not is_class_method_output:
            if collective_op is None:
                raise ValueError("Expected a collective operation")

        super().__init__(
            method_name,
            method_args,
            method_kwargs,
            method_options,
            other_args_to_resolve,
        )

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        """Build a new `CollectiveOutputNode` with the same method name and new arguments."""
        copied_node = CollectiveOutputNode(
            self._method_name,
            new_args,
            new_kwargs,
            new_options,
            new_other_args_to_resolve,
        )
        return copied_node

    def _execute_impl(self, *args, **kwargs):
        """Always raise: this node cannot be executed outside a compiled DAG."""
        unsupported_msg = (
            "CollectiveOutputNode is only supported with dag.experimental_compile()"
        )
        raise NotImplementedError(unsupported_msg)

    @property
    def collective_op(self) -> _CollectiveOperation:
        """The collective operation stored on this node at construction."""
        return self._collective_op
