import asyncio
from collections import defaultdict
from typing import Optional, Tuple
from unittest import mock

import torch

import ray
import ray.dag
import ray.experimental.channel as ray_channel
from ray.experimental.channel import nccl_group
from ray.experimental.channel.communicator import TorchTensorAllocator
from ray.experimental.util.types import Device


@ray.remote(num_cpus=0)
class Barrier:
    """
    Rendezvous point for a fixed number of actors: a call to `wait` does not
    return until `num_actors` callers have made the matching call. This is
    used to mock out blocking NCCL ops.
    """

    def __init__(self, num_actors=2):
        # Number of callers that must arrive for an op index before any of
        # them is released.
        self.num_actors = num_actors
        self.condition = asyncio.Condition()
        # Payload "sent" between the actors, keyed by p2p op index.
        self.data = {}
        # Arrival count, keyed by p2p op index.
        self.num_actors_seen = defaultdict(int)

        # Patch TorchTensorType.device so that it returns Device.CPU. The
        # patch is started here and never stopped.
        mock.patch(
            "ray.experimental.channel.torch_tensor_type.TorchTensorType.device",
            new_callable=mock.PropertyMock,
            return_value=Device.CPU,
        ).start()

    async def wait(self, idx: int, data=None):
        """
        Block until `num_actors` callers have called `wait` with this `idx`.

        Args:
            idx: Index of the p2p op being synchronized.
            data: Payload to hand to the other callers. One caller per `idx`
                should pass a non-None value; a second non-None value for
                the same `idx` trips an assertion.

        Returns:
            The caller's own `data` if it passed one, otherwise the payload
            stored for `idx` by another caller.
        """

        def all_arrived():
            return self.num_actors_seen[idx] == self.num_actors

        async with self.condition:
            if data is not None:
                assert idx not in self.data, (self.data, self.num_actors_seen)
                self.data[idx] = data
            self.num_actors_seen[idx] += 1

            if all_arrived():
                # Last one in: release every task parked on this condition.
                self.condition.notify_all()
            else:
                await self.condition.wait_for(all_arrived)

            # Callers that did not bring a payload pick up the stored one.
            payload = self.data[idx] if data is None else data

        return payload


class MockCudaStream:
    """Minimal stand-in object exposing `cuda_stream` and `synchronize`."""

    def __init__(self):
        # Placeholder stream handle.
        self.cuda_stream = 0

    def synchronize(self):
        """Do nothing; there is no work to wait for."""
        return None


class MockNcclGroup(nccl_group._NcclGroup):
    """
    Mock the internal _NcclGroup to use a barrier actor instead of a NCCL group
    for communication.

    Every pair of ranks communicates through a named actor looked up as
    "barrier-<lower rank>-<higher rank>". Both sides call the actor's `wait`
    method with the same op index; the sender passes the tensor and the
    receiver gets it back as the return value.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # We use the op index to synchronize the sender and receiver at the
        # barrier. One counter per barrier name.
        self.num_ops = defaultdict(int)
        # Handles of every barrier actor this group has used, so that
        # `destroy` can kill them.
        self.barriers = set()

    def _lookup_mock_barrier(self, peer_rank: int):
        """
        Resolve the barrier actor shared by this rank and `peer_rank`.

        The two ranks are sorted so that both sides derive the same actor
        name. The handle is recorded in `self.barriers`.

        Returns:
            A `(barrier_key, barrier)` tuple: the actor name (also the key
            into `self.num_ops`) and the actor handle.
        """
        low_rank, high_rank = sorted([self.get_self_rank(), peer_rank])
        barrier_key = f"barrier-{low_rank}-{high_rank}"
        barrier = ray.get_actor(name=barrier_key)
        self.barriers.add(barrier)
        return barrier_key, barrier

    def send(self, tensor: torch.Tensor, peer_rank: int):
        """
        "Send" `tensor` to `peer_rank` by handing it to the barrier actor.

        Blocks until the barrier call for the current op index returns, then
        advances the op index for that barrier.
        """
        barrier_key, barrier = self._lookup_mock_barrier(peer_rank)
        ray.get(barrier.wait.remote(self.num_ops[barrier_key], tensor))
        self.num_ops[barrier_key] += 1

    def recv(
        self,
        shape: Tuple[int],
        dtype: torch.dtype,
        peer_rank: int,
        allocator: Optional[TorchTensorAllocator] = None,
    ):
        """
        "Receive" a tensor from `peer_rank` via the barrier actor.

        Args:
            shape: Shape passed to `allocator` for the output buffer.
            dtype: Dtype passed to `allocator` for the output buffer.
            peer_rank: Rank of the sending peer.
            allocator: Callable `(shape, dtype) -> tensor` used to allocate
                the output buffer. Required by this mock.

        Returns:
            The buffer produced by `allocator`, filled with the received
            tensor's values.

        Raises:
            AssertionError: If `allocator` is None.
        """
        barrier_key, barrier = self._lookup_mock_barrier(peer_rank)
        received_tensor = ray.get(barrier.wait.remote(self.num_ops[barrier_key]))
        assert (
            allocator is not None
        ), "torch tensor allocator is required for MockNcclGroup"
        buf = allocator(shape, dtype)
        # Ellipsis indexing copies into the whole buffer like `buf[:] = ...`
        # but also works for 0-dim (scalar) tensors, which cannot be sliced.
        buf[...] = received_tensor
        self.num_ops[barrier_key] += 1
        return buf

    def destroy(self) -> None:
        """Kill every barrier actor this group has looked up."""
        for barrier in self.barriers:
            ray.kill(barrier)


def start_nccl_mock():
    """
    Patch methods that require CUDA.

    Every patcher created here is started and left active; nothing in this
    function stops them.
    """
    # Mock cupy dependencies.
    fake_cupy_cuda = mock.MagicMock()
    fake_cupy_cuda.nccl.get_unique_id.return_value = 0
    mock.patch.dict(
        "sys.modules",
        {
            "cupy.cuda": fake_cupy_cuda,
            "cupy": mock.MagicMock(),
            "ray.util.collective.collective_group": mock.MagicMock(),
        },
    ).start()

    # Mock send/recv ops to use an actor instead of NCCL.
    ray.experimental.channel.nccl_group._NcclGroup = MockNcclGroup

    # PyTorch mocks: both stream entry points resolve to MockCudaStream.
    for stream_target in ("torch.cuda.current_stream", "torch.cuda.Stream"):
        mock.patch(stream_target, new_callable=lambda: MockCudaStream).start()

    # Replace the Tensor `device` and `is_cuda` attributes with CUDA values.
    tensor_attr_overrides = (
        ("torch.Tensor.device", torch.device("cuda")),
        ("torch.Tensor.is_cuda", True),
    )
    for attr_target, replacement in tensor_attr_overrides:
        mock.patch(attr_target, replacement).start()

    # Allocate receive buffers with a plain torch.empty.
    mock.patch(
        "ray.experimental.channel.torch_tensor_accelerator_channel._torch_tensor_allocator",
        lambda shape, dtype: torch.empty(shape, dtype=dtype),
    ).start()

    # Add a new mock for the TorchTensorType.device property
    mock.patch(
        "ray.experimental.channel.torch_tensor_type.TorchTensorType.device",
        new_callable=mock.PropertyMock,
        return_value=Device.CPU,
    ).start()

    # Set the current channel context's torch device to "cuda".
    channel_ctx = ray_channel.ChannelContext.get_current()
    channel_ctx.set_torch_device(torch.device("cuda"))


class TracedChannel(ray_channel.shared_memory_channel.Channel):
    """
    Channel subclass for tests that keeps a log of every `write` call.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One `(args, kwargs)` entry per `write` call, in call order.
        self.ops = []

    def write(self, *args, **kwargs):
        """Record the call's arguments, then delegate to the parent `write`."""
        call_record = (args, kwargs)
        self.ops.append(call_record)
        return super().write(*args, **kwargs)
