import uuid
from typing import Dict, FrozenSet, List, Optional, Set, Tuple, Type

import torch

import ray
from ray.experimental.channel.common import ChannelContext
from ray.experimental.channel.communicator import (
    Communicator,
    ReduceOp,
    TorchTensorAllocator,
)


class AbstractNcclGroup(Communicator):
    """
    A dummy NCCL group for testing.

    It tracks membership (the list of actor handles) and the local rank, but
    performs no real communication: every collective / point-to-point op
    raises ``NotImplementedError`` and no CUDA streams are exposed.
    """

    def __init__(self, actor_handles: List[ray.actor.ActorHandle]):
        # Order of the handles defines each actor's rank (see ``get_rank``).
        self._actor_handles = actor_handles
        # Unset until ``initialize`` is called on the participating process.
        self._rank = None

    def initialize(self, rank: int) -> None:
        """Record the rank of the current process within the group."""
        self._rank = rank

    def get_rank(self, actor: ray.actor.ActorHandle) -> int:
        """Return ``actor``'s position in the handle list.

        Raises:
            ValueError: if ``actor`` is not a member of this group.
        """
        return self._actor_handles.index(actor)

    def get_world_size(self) -> int:
        """Return the number of actors in the group."""
        return len(self._actor_handles)

    def get_self_rank(self) -> Optional[int]:
        """Return the rank set by ``initialize``, or None if not initialized."""
        return self._rank

    def get_actor_handles(self) -> List["ray.actor.ActorHandle"]:
        """Return the actor handles that make up this group."""
        return self._actor_handles

    def send(self, value: "torch.Tensor", peer_rank: int) -> None:
        """Not supported by the dummy group."""
        raise NotImplementedError

    def recv(
        self,
        shape: Tuple[int],
        dtype: "torch.dtype",
        peer_rank: int,
        allocator: Optional[TorchTensorAllocator] = None,
    ) -> "torch.Tensor":
        """Not supported by the dummy group."""
        raise NotImplementedError

    def allgather(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
    ) -> None:
        """Not supported by the dummy group."""
        raise NotImplementedError

    def allreduce(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp = ReduceOp.SUM,
    ) -> None:
        """Not supported by the dummy group."""
        raise NotImplementedError

    def reducescatter(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp = ReduceOp.SUM,
    ) -> None:
        """Not supported by the dummy group."""
        raise NotImplementedError

    @property
    def recv_stream(self):
        # The dummy group has no dedicated receive stream.
        return None

    @property
    def send_stream(self):
        # The dummy group has no dedicated send stream.
        return None

    def destroy(self) -> None:
        """No resources to release."""
        pass

    def get_transport_name(self) -> str:
        """Return the transport name reported by this communicator."""
        return "accelerator"

    @classmethod
    def generate_communicator_id(cls) -> str:
        """Return a fresh, unique communicator ID.

        The previous body returned None despite the ``str`` annotation; a
        random UUID string honors the declared contract.
        """
        return str(uuid.uuid4())


class MockNcclGroupSet:
    """
    Callable stand-in for ``_init_communicator`` that records which actors (and
    which optional custom communicator) each NCCL group was created for.

    Groups are backed by ``AbstractNcclGroup`` unless a custom communicator is
    supplied, so no real NCCL communicator is ever created.
    """

    def __init__(self):
        # Represents a mapping from a NCCL group ID to a set of actors and a custom
        # NCCL group (None when the default dummy group was used).
        self.ids_to_actors_and_custom_comms: Dict[
            str, Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[Communicator]]
        ] = {}

    def __call__(
        self,
        actors: List["ray.actor.ActorHandle"],
        custom_nccl_group: Optional[Communicator] = None,
        use_communication_streams: bool = False,
        accelerator_module_name: Optional[str] = None,
        accelerator_communicator_cls: Optional[Type[Communicator]] = None,
    ) -> str:
        """Create a mock NCCL group for ``actors`` and return its new ID.

        ``use_communication_streams``, ``accelerator_module_name`` and
        ``accelerator_communicator_cls`` are accepted only to match the
        signature being replaced; they are ignored.
        """
        group_id = str(uuid.uuid4())
        self.ids_to_actors_and_custom_comms[group_id] = (
            frozenset(actors),
            custom_nccl_group,
        )

        # Default group: rank == position in ``actors``. Custom group: ask it.
        if custom_nccl_group is None:
            ranks = list(range(len(actors)))
        else:
            ranks = [custom_nccl_group.get_rank(actor) for actor in actors]
        # Register the group inside every participating actor process.
        init_tasks = [
            actor.__ray_call__.remote(
                mock_do_init_nccl_group,
                group_id,
                rank,
                actors,
                custom_nccl_group,
            )
            for rank, actor in zip(ranks, actors)
        ]
        ray.get(init_tasks, timeout=30)

        # Register the group on the calling (driver) process as well.
        ctx = ChannelContext.get_current()
        if custom_nccl_group is not None:
            ctx.communicators[group_id] = custom_nccl_group
        else:
            ctx.communicators[group_id] = AbstractNcclGroup(actors)

        return group_id

    def mock_destroy_nccl_group(self, group_id: str) -> None:
        """Tear down ``group_id`` on its actors and on the calling process.

        No-op if the group is not registered in the current ChannelContext.
        Actor-side teardown is best-effort: it waits up to 30 seconds for
        every actor but does not raise on timeout.
        """
        ctx = ChannelContext.get_current()
        if group_id not in ctx.communicators:
            return

        # pop() with a default so a group registered in the context but not
        # through this mock does not raise KeyError.
        entry = self.ids_to_actors_and_custom_comms.pop(group_id, None)
        if entry is not None:
            actors, _ = entry
            destroy_tasks = [
                actor.__ray_call__.remote(
                    mock_do_destroy_nccl_group,
                    group_id,
                )
                for actor in actors
            ]
            if destroy_tasks:
                # ray.wait defaults to num_returns=1, which would only wait for
                # the first actor; wait for all of them.
                ray.wait(
                    destroy_tasks, num_returns=len(destroy_tasks), timeout=30
                )

        ctx.communicators[group_id].destroy()
        del ctx.communicators[group_id]

    def check_teardown(self, nccl_group_ids: List[str]) -> None:
        """Assert every ID in ``nccl_group_ids`` has been fully torn down."""
        ctx = ChannelContext.get_current()
        for nccl_group_id in nccl_group_ids:
            assert nccl_group_id not in self.ids_to_actors_and_custom_comms
            assert nccl_group_id not in ctx.communicators


@ray.remote
class CPUTorchTensorWorker:
    """Ray actor that creates and receives torch tensors on the CPU."""

    def __init__(self):
        self.device = "cpu"

    def return_tensor(
        self, size: int, dtype: Optional[torch.dtype] = None
    ) -> torch.Tensor:
        """Return a tensor of ``size`` ones on this worker's device."""
        return torch.ones(size, dtype=dtype, device=self.device)

    def recv(self, tensor: torch.Tensor) -> Tuple[torch.Size, torch.Tensor]:
        """Check ``tensor`` lives on this worker's device.

        Returns:
            The tensor's shape and its first element (as a tensor).
        """
        # Compare device objects rather than a torch.device against a raw str,
        # so the check does not depend on torch.device/str equality semantics.
        assert tensor.device == torch.device(self.device)
        return tensor.shape, tensor[0]

    def recv_tensors(self, *tensors) -> Tuple[torch.Tensor, ...]:
        """Return the received tensors unchanged, as a tuple."""
        return tuple(tensors)


def mock_do_init_nccl_group(
    self,
    group_id: str,
    rank: int,
    actors: List[ray.actor.ActorHandle],
    custom_nccl_group: Optional[Communicator],
) -> None:
    """Register a mock NCCL group inside the process this runs in.

    Intended to be invoked on an actor via ``actor.__ray_call__``, so ``self``
    is the actor instance (unused here).

    Args:
        group_id: ID under which the communicator is stored.
        rank: Rank to initialize the communicator with.
        actors: Group members, used to build the default dummy group.
        custom_nccl_group: Communicator to use instead of the dummy group.
    """
    channel_ctx = ChannelContext.get_current()
    communicator = (
        custom_nccl_group
        if custom_nccl_group is not None
        else AbstractNcclGroup(actors)
    )
    communicator.initialize(rank)
    channel_ctx.communicators[group_id] = communicator


def mock_do_destroy_nccl_group(self, group_id: str) -> None:
    """Destroy and unregister ``group_id`` in the process this runs in.

    Intended to be invoked on an actor via ``actor.__ray_call__``; ``self`` is
    unused. Does nothing if the group is not registered.
    """
    registry = ChannelContext.get_current().communicators
    if group_id in registry:
        registry[group_id].destroy()
        del registry[group_id]


def check_nccl_group_init(
    monkeypatch,
    dag: "ray.dag.DAGNode",
    actors_and_custom_comms: Set[
        Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[Communicator]]
    ],
) -> Tuple["ray.dag.CompiledDAG", MockNcclGroupSet]:
    """Compile ``dag`` with NCCL group creation mocked, and verify the groups.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture.
        dag: DAG to compile.
        actors_and_custom_comms: Expected set of (actor set, custom
            communicator or None) pairs, one per created group.

    Returns:
        The compiled DAG and the ``MockNcclGroupSet`` that recorded the groups
        (pass both to ``check_nccl_group_teardown``).
    """
    mock_nccl_group_set = MockNcclGroupSet()
    monkeypatch.setattr(
        "ray.dag.compiled_dag_node._init_communicator",
        mock_nccl_group_set,
    )

    compiled_dag = dag.experimental_compile()
    created = set(mock_nccl_group_set.ids_to_actors_and_custom_comms.values())
    assert created == actors_and_custom_comms, (
        f"Unexpected NCCL groups: got {created}, "
        f"expected {actors_and_custom_comms}"
    )

    return compiled_dag, mock_nccl_group_set


def check_nccl_group_teardown(
    monkeypatch,
    compiled_dag: "ray.dag.CompiledDAG",
    mock_nccl_group_set: MockNcclGroupSet,
) -> None:
    """Tear down ``compiled_dag`` with NCCL destruction mocked and verify it.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture.
        compiled_dag: DAG returned by ``check_nccl_group_init``.
        mock_nccl_group_set: Mock returned by ``check_nccl_group_init``.
    """
    monkeypatch.setattr(
        "ray.dag.compiled_dag_node._destroy_communicator",
        mock_nccl_group_set.mock_destroy_nccl_group,
    )

    # Snapshot the IDs *before* teardown: ``.values()`` is a live view, and if
    # teardown mutates the mapping the check below would pass vacuously.
    created_communicator_ids = list(
        compiled_dag._actors_to_created_communicator_id.values()
    )
    compiled_dag.teardown()
    mock_nccl_group_set.check_teardown(created_communicator_ids)
