import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Set, Tuple, Union

from ray.experimental.util.types import Device

if TYPE_CHECKING:
    import numpy as np
    import torch


_TORCH_WARNING_FILTER_ACTIVATE = True


class _SerializationContext:
    def __init__(self):
        # If true, then tensors found in the data to serialize are extracted
        # and the caller should send them through an external transport.
        self._use_external_transport: bool = False
        # If _use_external_transport is True, then these are
        # the tensors that should be sent or received
        # out-of-band, through the external transport.
        self._out_of_band_tensors: List["torch.Tensor"] = []
        # During serialization, tensors sent out-of-band are replaced with
        # integer placeholders. This tracks the set of placeholders seen.
        self._deserialized_tensor_placeholders: Set[int] = set()

        # Buffer for transferring data between tasks in the same worker process.
        # The key is the channel ID, and the value is the data. We don't use a
        # lock when reading/writing the buffer because a DAG node actor will only
        # execute one task at a time in `do_exec_tasks`. It will not execute multiple
        # Ray tasks on a single actor simultaneously.
        self.intra_process_channel_buffers: Dict[str, Any] = {}
        # The number of readers for each channel. When the number of readers
        # reaches 0, remove the data from the buffer.
        self.channel_id_to_num_readers: Dict[str, int] = {}

    def set_target_device(self, device: Device) -> None:
        self._target_device = device

    def set_data(self, channel_id: str, value: Any, num_readers: int) -> None:
        assert num_readers > 0, "num_readers must be greater than 0."
        assert (
            channel_id not in self.intra_process_channel_buffers
        ), f"Channel {channel_id} already exists in the buffer."
        assert (
            channel_id not in self.channel_id_to_num_readers
        ), f"Channel {channel_id} already exists in the channel_id_to_num_readers."

        self.intra_process_channel_buffers[channel_id] = value
        self.channel_id_to_num_readers[channel_id] = num_readers

    def has_data(self, channel_id: str) -> bool:
        return channel_id in self.intra_process_channel_buffers

    def get_data(self, channel_id: str) -> Any:
        assert (
            channel_id in self.intra_process_channel_buffers
        ), f"Channel {channel_id} does not exist in the buffer."
        assert (
            channel_id in self.channel_id_to_num_readers
        ), f"Channel {channel_id} does not exist in the channel_id_to_num_readers."

        self.channel_id_to_num_readers[channel_id] -= 1
        if self.channel_id_to_num_readers[channel_id] == 0:
            # All readers have read the data, so we can remove it.
            self.channel_id_to_num_readers.pop(channel_id)
            return self.intra_process_channel_buffers.pop(channel_id)
        return self.intra_process_channel_buffers[channel_id]

    def reset_data(self, channel_id: str) -> None:
        self.intra_process_channel_buffers.pop(channel_id, None)
        self.channel_id_to_num_readers.pop(channel_id, None)

    def set_use_external_transport(self, use_external_transport: bool) -> None:
        self._use_external_transport = use_external_transport

    @property
    def use_external_transport(self) -> bool:
        return self._use_external_transport

    def reset_out_of_band_tensors(
        self, tensors: List["torch.Tensor"]
    ) -> Tuple[List["torch.Tensor"], Set[int]]:
        """
        Return and reset the out-of-band tensors and all tensor placeholders
        that were deserialized since the last call to reset.
        """
        prev_tensors = self._out_of_band_tensors
        deserialized_tensor_placeholders = self._deserialized_tensor_placeholders
        self._out_of_band_tensors = tensors
        self._deserialized_tensor_placeholders = set()
        return prev_tensors, deserialized_tensor_placeholders

    def serialize_tensor(
        self, tensor: "torch.Tensor"
    ) -> Union[int, Tuple["np.ndarray", "torch.dtype", str]]:
        from ray.experimental.channel import ChannelContext

        ctx = ChannelContext.get_current()
        if self._use_external_transport and (
            ctx._torch_device is None or ctx._torch_device == tensor.device
        ):
            # External transport is enabled and we found a tensor that matches
            # our device.  Add the actual tensor to a buffer. The buffer of
            # tensors should later be popped by the caller and sent via
            # external transport.
            self._out_of_band_tensors.append(tensor)
            # Return a placeholder.
            return len(self._out_of_band_tensors) - 1

        return self.serialize_to_numpy_or_scalar(tensor)

    def serialize_to_numpy_or_scalar(
        self, tensor: "torch.Tensor"
    ) -> Tuple[Union["np.ndarray", Any], "torch.dtype", str]:
        """
        Serialize a tensor to a numpy array,
        or a scalar when the tensor is 0-dim.
        """
        import torch

        tensor_device_type = tensor.device.type

        # Transfer through Ray's shared memory store for now.
        # TODO(swang): This requires two copies, one to transfer from GPU to
        # CPU and another from CPU to shared memory. Ideally we should elide
        # the first copy and memcpy directly from GPU to the shared memory
        # buffer.
        if tensor_device_type != "cpu":
            tensor = tensor.to("cpu")

        # Numpy does not have an equivalent dtype for all torch dtypes, so
        # instead of casting directly to numpy:
        # 1) for non-scalar tensors, we first use a view with a common dtype (uint8)
        #    and then view as numpy array.
        # 2) for scalar tensors, we cannot use a uint8 view when the size differs,
        #    so we save the original item and type information.
        if tensor.dim() > 0:
            return (tensor.view(torch.uint8).numpy(), tensor.dtype, tensor_device_type)
        else:
            return (tensor.item(), tensor.dtype, tensor_device_type)

    def deserialize_tensor(
        self,
        val: Union[Tuple["np.ndarray", "torch.dtype", str], int],
        target_device: Device,
    ):

        # Found a placeholder for a tensor that was serialized via accelerator.
        # Replace it with the corresponding deserialized tensor.
        if isinstance(val, int):
            placeholder = val
            self._deserialized_tensor_placeholders.add(placeholder)
            assert placeholder < len(self._out_of_band_tensors), (
                "placeholder",
                placeholder,
                "out_of_band_tensors",
                self._out_of_band_tensors,
            )
            tensor = self._out_of_band_tensors[placeholder]
            if target_device == Device.CPU:
                tensor = tensor.to("cpu")
            return tensor

        np_array, dtype, tensor_device_type = val
        return self.deserialize_from_numpy_or_scalar(
            np_array, dtype, tensor_device_type, target_device
        )

    def deserialize_from_numpy_or_scalar(
        self,
        np_array: Union["np.ndarray", Any],
        dtype: "torch.dtype",
        tensor_device_type: str,
        target_device: Device,
    ):
        import numpy as np
        import torch

        if target_device == Device.DEFAULT:
            target_device_type = tensor_device_type
        elif target_device in [Device.GPU, Device.CUDA]:
            target_device_type = "cuda"
        else:
            target_device_type = target_device.value

        # TODO(swang): Support local P2P transfers if available.
        if target_device_type != "cpu":

            def convert_numpy_to_tensor(np_array):
                if not isinstance(np_array, np.ndarray):
                    # For scalar tensors, create the 0-dim tensor.
                    return torch.tensor(
                        np_array, device=target_device_type, dtype=dtype
                    )
                else:
                    # For non-scalar tensors, view as the original dtype.
                    # It does zero-copy convert np_array inside shared memory to
                    # a tensor. Since we move data to GPU immediately, it is safe.
                    cpu_tensor = torch.from_numpy(np_array).view(dtype)
                    return cpu_tensor.to(device=target_device_type)

            global _TORCH_WARNING_FILTER_ACTIVATE
            # filtering warning messages would be the bottleneck for
            # deserializing torch tensors. Since the warning only prompts once,
            # we would only deal with it for the first time.
            if _TORCH_WARNING_FILTER_ACTIVATE:
                with warnings.catch_warnings():
                    # Since np_array.is_writable is False (it is set by Ray),
                    # this raises a warning. Suppress it.
                    warnings.filterwarnings(
                        "ignore",
                        category=UserWarning,
                        message="The given NumPy array is not writable",
                    )
                    gpu_tensor = convert_numpy_to_tensor(np_array)
                _TORCH_WARNING_FILTER_ACTIVATE = False
            else:
                gpu_tensor = convert_numpy_to_tensor(np_array)

            return gpu_tensor

        # TODO(swang): Use zero-copy from_numpy() if np_array.flags.writeable
        # is True. This is safe to set when deserializing np_array if the
        # upstream task has num_readers=1.
        if not isinstance(np_array, np.ndarray):
            # For scalar tensors, create the 0-dim tensor.
            return torch.tensor(np_array, device=target_device_type, dtype=dtype)
        else:
            # For non-scalar tensors, view as the original dtype.
            return torch.tensor(np_array, device=target_device_type).view(dtype)
