import asyncio
import concurrent
import sys
import threading
import time
from dataclasses import dataclass
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

import ray
import ray.exceptions
from ray.experimental.channel.accelerator_context import AcceleratorContext
from ray.experimental.channel.communicator import Communicator
from ray.experimental.channel.communicator_handle import CommunicatorHandle
from ray.experimental.channel.serialization_context import _SerializationContext
from ray.util.annotations import DeveloperAPI, PublicAPI

# The ChannelContext singleton for this process. It is created lazily by
# ChannelContext.get_current().
_default_context: "Optional[ChannelContext]" = None
# Guards the lazy creation of `_default_context` in ChannelContext.get_current().
_context_lock = threading.Lock()

if TYPE_CHECKING:
    import torch


def retry_and_check_interpreter_exit(f: Callable[[], None]) -> bool:
    """Call ``f`` repeatedly until it succeeds, retrying on channel timeouts.

    This is only useful when ``f`` performs a channel read/write. Every
    ``RayChannelTimeoutError`` raised by ``f`` triggers another attempt. If the
    interpreter is shutting down, we give up instead. This matters when the
    read/write runs in a separate thread pool: without this check, that thread
    could never be joined.
    See https://github.com/ray-project/ray/pull/47702

    Args:
        f: A zero-argument callable. Its return value is ignored.

    Returns:
        True if we stopped retrying because the interpreter is exiting.
        False if ``f`` completed without raising.
    """
    while True:
        try:
            f()
        except ray.exceptions.RayChannelTimeoutError:
            if sys.is_finalizing():
                # The interpreter is exiting. Ignore the error and stop so
                # that the calling thread can be joined.
                return True
            # Otherwise, fall through and retry.
        else:
            return False


# Holds the input arguments for Compiled Graph
@PublicAPI(stability="alpha")
class CompiledDAGArgs(NamedTuple):
    """Positional and keyword arguments passed as input to a Compiled Graph.

    When a DAG input writer writes one of these, an integer index selects an
    entry from `args` and a string key selects an entry from `kwargs`.
    """

    # Positional input arguments, indexed by int.
    args: Tuple[Any, ...]
    # Keyword input arguments, indexed by str.
    kwargs: Dict[str, Any]


@PublicAPI(stability="alpha")
class ChannelOutputType:
    """
    Describes the type of values passed through a channel.

    Subclasses decide which concrete ChannelInterface is created for values of
    this type (see `create_channel`). They also decide whether that channel
    needs an accelerator or a custom communicator. In this base class, those
    default to False and None.
    """

    def register_custom_serializer(self) -> None:
        """
        Register any custom serializers needed to pass data of this type. This
        method should be run on the reader(s) and writer of a channel, which
        are the driver and/or Ray actors.

        NOTE: When custom serializers are registered with Ray, the registered
        deserializer is shipped with the serialized value and used on the
        receiving end. Therefore, the deserializer function should *not*
        capture state that is meant to be worker-local, such as the worker's
        default device. Instead, these should be extracted from the
        worker-local _SerializationContext.
        """
        pass

    def create_channel(
        self,
        writer: Optional["ray.actor.ActorHandle"],
        reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]],
        driver_actor_id: Optional[str] = None,
    ) -> "ChannelInterface":
        """
        Instantiate a ChannelInterface class that can be used
        to pass data of this type.

        Args:
            writer: The actor that may write to the channel. None signifies the driver.
            reader_and_node_list: A list of tuples, where each tuple contains a reader
                actor handle and the node ID where the actor is located.
            driver_actor_id: If this is a CompositeChannel that is read by a driver and
                that driver is an actual actor, this will be the actor ID of that
                driver actor.

        Returns:
            A ChannelInterface that can be used to pass data
                of this type.

        Raises:
            NotImplementedError: Always, in the base class. Subclasses must
                override this method.
        """
        raise NotImplementedError

    def requires_accelerator(self) -> bool:
        """
        Return whether channels for this type need an accelerator.
        """
        # By default, channels do not require accelerator.
        return False

    def get_custom_communicator(self) -> Optional[Communicator]:
        """
        Return the custom communicator group if one is specified.
        """
        return None

    def set_communicator_id(self, group_id: str) -> None:
        """
        Set the communicator group ID for this type.

        NOTE(review): Presumably this records `group_id` as the communicator
        group used by accelerator channels. Confirm against the subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError


@DeveloperAPI
@dataclass
class ChannelContext:
    """
    Context shared by all channels in this process.

    Holds the serialization context, the lazily computed torch availability
    and torch device, and the communicators used for accelerator transport.
    On the driver, it also holds the communicator handles. Obtain the
    process-wide instance with `get_current()`.
    """

    # This attribute has no annotation, so it is a plain class attribute
    # shared by all instances, not a dataclass field.
    serialization_context = _SerializationContext()
    _torch_available: Optional[bool] = None
    _torch_device: Optional["torch.device"] = None
    _current_stream: Optional["torch.cuda.Stream"] = None

    def __init__(self):
        # Communicators for the torch.Tensor accelerator transport, keyed by
        # group ID string.
        self.communicators: Dict[str, "Communicator"] = {}
        # Used by the driver process to keep track of the actors in each
        # communicator.
        self.communicator_handles: Dict[str, "CommunicatorHandle"] = {}

    @staticmethod
    def get_current() -> "ChannelContext":
        """Return the process-wide context, creating it on first use.

        The first call builds a ChannelContext with default settings. Every
        later call returns that same instance. Creation is serialized by a
        module-level lock.
        """
        global _default_context

        with _context_lock:
            if _default_context is None:
                _default_context = ChannelContext()
            context = _default_context
        return context

    @property
    def torch_available(self) -> bool:
        """
        Whether the torch package can be imported. The result is cached after
        the first check.
        """
        if self._torch_available is None:
            try:
                import torch  # noqa: F401

                self._torch_available = True
            except ImportError:
                self._torch_available = False
        return self._torch_available

    @property
    def torch_device(self) -> "torch.device":
        """
        The torch device used by this process.

        If it has not been set, it defaults to the first device reported by
        the AcceleratorContext. That value is cached.
        """
        if self._torch_device is not None:
            return self._torch_device

        devices = AcceleratorContext.get().get_accelerator_devices()
        self._torch_device = devices[0]
        return self._torch_device

    def set_torch_device(self, device: "torch.device"):
        """Override the torch device returned by `torch_device`."""
        self._torch_device = device


@PublicAPI(stability="alpha")
class ChannelInterface:
    """
    Abstraction for a transport between a writer actor and some number of
    reader actors.

    Every method in this base class is abstract: it raises NotImplementedError
    and concrete channel types must override it. The constructor is a no-op.
    """

    def __init__(
        self,
        writer: Optional[ray.actor.ActorHandle],
        readers: List[Optional[ray.actor.ActorHandle]],
        typ: Optional["ChannelOutputType"],
    ):
        """
        Create a channel that can be read and written by a Ray driver or actor.

        Args:
            writer: The actor that may write to the channel. None signifies the driver.
            readers: The actors that may read from the channel. None signifies
                the driver.
            typ: Type information about the values passed through the channel.
        """
        # The base class stores nothing; subclasses hold their own state.
        pass

    def ensure_registered_as_writer(self):
        """
        Check whether the process is a valid writer. This method must be idempotent.
        """
        raise NotImplementedError

    def ensure_registered_as_reader(self):
        """
        Check whether the process is a valid reader. This method must be idempotent.
        """
        raise NotImplementedError

    def write(self, value: Any, timeout: Optional[float] = None) -> None:
        """
        Write a value to the channel.

        Blocks if there are still pending readers for the previous value. The
        writer may not write again until the specified number of readers have
        read the value.

        Args:
            value: The value to write.
            timeout: The maximum time in seconds to wait to write the value.
                None means using default timeout, 0 means immediate timeout
                (immediate success or timeout without blocking), -1 means
                infinite timeout (block indefinitely).
        """
        raise NotImplementedError

    def read(self, timeout: Optional[float] = None) -> Any:
        """
        Read the latest value from the channel. This call will block until a
        value is available to read.

        Subsequent calls to read() may *block* if the deserialized object is
        zero-copy (e.g., bytes or a numpy array) *and* the object is still in scope.

        Args:
            timeout: The maximum time in seconds to wait to read the value.
                None means using default timeout, 0 means immediate timeout
                (immediate success or timeout without blocking), -1 means
                infinite timeout (block indefinitely).

        Returns:
            Any: The deserialized value. If the deserialized value is an
            Exception, it will be returned directly instead of being raised.
        """
        raise NotImplementedError

    def close(self) -> None:
        """
        Close this channel. This method must not block and it must be made
        idempotent. Any existing values in the channel may be lost after the
        channel is closed.
        """
        raise NotImplementedError


# Interfaces for channel I/O.
@DeveloperAPI
class ReaderInterface:
    """
    Base class for objects that read one value from each input channel per
    `read()` call. The list of input channels is fixed at construction.

    Subclasses implement `start()` and `_read_list()`. `read()` validates the
    timeout, delegates to `_read_list()`, and counts completed reads.
    """

    def __init__(
        self,
        input_channels: List[ChannelInterface],
    ):
        assert isinstance(input_channels, list)
        for chan in input_channels:
            assert isinstance(chan, ChannelInterface)

        self._input_channels = input_channels
        self._closed = False
        self._num_reads = 0

        # A list of channels that were not read in the last `read` call
        # because the reader returned immediately when a RayTaskError was found.
        # These channels must be consumed before the next read to avoid reading
        # stale data remaining from the last read.
        self._leftover_channels: List[ChannelInterface] = []

    def get_num_reads(self) -> int:
        """Return how many `read()` calls have returned without raising."""
        return self._num_reads

    def start(self):
        """Start the reader. Must be implemented by subclasses."""
        raise NotImplementedError

    def _read_list(self, timeout: Optional[float] = None) -> List[Any]:
        """
        Read a list of values from this reader.

        Args:
            timeout: The maximum time in seconds to wait for reading.
                None means using default timeout which is infinite, 0 means immediate
                timeout (immediate success or timeout without blocking), -1 means
                infinite timeout (block indefinitely).

        """
        raise NotImplementedError

    def read(self, timeout: Optional[float] = None) -> List[Any]:
        """
        Read from this reader.

        Args:
            timeout: The maximum time in seconds to wait for reading.
                None means using default timeout, 0 means immediate timeout
                (immediate success or timeout without blocking), -1 means
                infinite timeout (block indefinitely).

        Returns:
            The list produced by `_read_list()`.

        Raises:
            AssertionError: If `timeout` is negative and not -1.
        """
        assert (
            timeout is None or timeout >= 0 or timeout == -1
        ), "Timeout must be non-negative or -1."
        outputs = self._read_list(timeout)
        self._num_reads += 1
        return outputs

    def close(self) -> None:
        """Mark the reader closed and close every input channel."""
        self._closed = True
        for channel in self._input_channels:
            channel.close()

    def _consume_leftover_channels_if_needed(
        self, timeout: Optional[float] = None
    ) -> None:
        # Consume the channels that were not read in the last `read` call because a
        # RayTaskError was returned from another channel. If we don't do this, the
        # read operation will read stale versions of the object refs.
        #
        # If a RayTaskError is returned from a leftover channel, it will be ignored.
        # If a read operation times out, a RayChannelTimeoutError exception will be
        # raised.
        #
        # A channel is removed from `_leftover_channels` only after it has been
        # read. If a read times out, the channels that were already drained are
        # therefore not read again on the next attempt. Re-reading them would
        # consume fresh data instead of stale data.
        #
        # TODO(kevin85421): Currently, a DAG with NCCL channels and fast fail enabled
        # may not be reusable. Revisit this in the future.
        while self._leftover_channels:
            c = self._leftover_channels[0]
            start_time = time.monotonic()
            c.read(timeout)
            self._leftover_channels.pop(0)
            # None and -1 both mean "no deadline", so there is no budget to
            # shrink. Decrementing -1 would clamp it to 0 and make the
            # remaining reads non-blocking.
            if timeout is not None and timeout != -1:
                timeout = max(timeout - (time.monotonic() - start_time), 0)


@DeveloperAPI
class SynchronousReader(ReaderInterface):
    """Reader that reads all input channels inline, on the calling thread."""

    def __init__(
        self,
        input_channels: List[ChannelInterface],
    ):
        super().__init__(input_channels)

    def start(self):
        # Nothing to start: every read happens inline in `read()`.
        pass

    def _read_list(self, timeout: Optional[float] = None) -> List[Any]:
        """
        Read one value from every input channel, failing fast on task errors.

        Channels are polled round-robin with at most
        `DAGContext.read_iteration_timeout` per attempt. This lets a
        RayTaskError on any channel be noticed early.

        Args:
            timeout: Total time budget in seconds. None or -1 mean infinite.
                0 means each channel is read once without blocking.

        Returns:
            One value per input channel, in channel order. If any channel yields
            a RayTaskError, a list filled with that error is returned instead.
            The unread channels are recorded so they can be drained before the
            next read.

        Raises:
            RayChannelTimeoutError: If not all channels were read in time.
        """
        self._consume_leftover_channels_if_needed(timeout)
        # We don't update `remaining_timeout` here because in the worst case,
        # consuming leftover channels requires reading all `_input_channels`,
        # which users expect to complete within the original `timeout`. Updating
        # `remaining_timeout` could cause unexpected timeouts in subsequent read
        # operations.

        # It is a special case that `timeout` is set to 0, which means
        # read once for each channel.
        is_zero_timeout = timeout == 0

        num_channels = len(self._input_channels)
        results = [None] * num_channels
        if timeout is None or timeout == -1:
            timeout = float("inf")
        timeout_point = time.monotonic() + timeout
        remaining_timeout = timeout

        from ray.dag import DAGContext

        ctx = DAGContext.get_current()
        iteration_timeout = ctx.read_iteration_timeout

        # Iterate over the input channels with a shorter timeout for each iteration
        # to detect RayTaskError early and fail fast.
        done_channels = set()
        while len(done_channels) < num_channels:
            for i, c in enumerate(self._input_channels):
                if c in done_channels:
                    continue
                try:
                    result = c.read(min(remaining_timeout, iteration_timeout))
                except ray.exceptions.RayChannelTimeoutError:
                    remaining_timeout = max(timeout_point - time.monotonic(), 0)
                    if remaining_timeout == 0:
                        raise
                    continue

                results[i] = result
                done_channels.add(c)
                if isinstance(result, ray.exceptions.RayTaskError):
                    # If we raise an exception immediately, it will be considered
                    # as a system error which will cause the execution loop to
                    # exit. Hence, return immediately and let `_process_return_vals`
                    # handle the exception.
                    #
                    # Return a list of RayTaskError so that the caller will not
                    # get an undefined partial result.
                    self._leftover_channels = [
                        ch for ch in self._input_channels if ch not in done_channels
                    ]
                    return [result] * num_channels

                remaining_timeout = max(timeout_point - time.monotonic(), 0)
                # Only time out if some channel is still unread. If the read that
                # just finished was the last one, every value is in hand, and
                # raising would throw away fully-read results.
                if (
                    remaining_timeout == 0
                    and not is_zero_timeout
                    and len(done_channels) < num_channels
                ):
                    raise ray.exceptions.RayChannelTimeoutError(
                        f"Cannot read all channels within {timeout} seconds"
                    )
        return results

    def release_channel_buffers(self, timeout: Optional[float] = None) -> None:
        """
        Release the buffer of every input channel.

        Args:
            timeout: Total time budget in seconds, shared by all channels.
                None or -1 mean no deadline.

        Raises:
            AssertionError: If an input channel has no `release_buffer()`.
        """
        for c in self._input_channels:
            start_time = time.monotonic()
            # The message must be parenthesized. Otherwise only the first
            # literal becomes the assert message and the rest are discarded
            # expression statements.
            assert hasattr(c, "release_buffer"), (
                "release_buffer() is only supported for shared memory channel "
                "(e.g., Channel, BufferedSharedMemoryChannel, CompositeChannel) "
                "and used between the last actor and the driver, but got a channel"
                f" of type {type(c)}."
            )
            c.release_buffer(timeout)
            # None and -1 both mean "no deadline"; do not shrink them.
            if timeout is not None and timeout != -1:
                timeout = max(timeout - (time.monotonic() - start_time), 0)


@DeveloperAPI
class AwaitableBackgroundReader(ReaderInterface):
    """
    Asyncio-compatible channel reader.

    The reader is constructed with an async queue of futures whose values it
    will fulfill. It uses a threadpool to execute the blocking calls to read
    from the input channel(s).
    """

    def __init__(
        self,
        input_channels: List[ChannelInterface],
        fut_queue: asyncio.Queue,
    ):
        super().__init__(input_channels)
        self._fut_queue = fut_queue
        self._background_task = None
        # A single worker, so consecutive reads run one after another and
        # never overlap.
        self._background_task_executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=1, thread_name_prefix="channel.AwaitableBackgroundReader"
        )

    def start(self):
        """Schedule `run()` as a task on the current event loop."""
        self._background_task = asyncio.ensure_future(self.run())

    def _run(self):
        """
        Blocking read of one value from every input channel.

        Runs on the executor thread. Channel timeouts are retried until every
        channel has been read. If the interpreter is finalizing, the (possibly
        partial) results are returned early. If a channel yields a RayTaskError,
        a list filled with that error is returned.
        """
        # Give it a default timeout 60 seconds to release the buffers
        # of the channels that were not read in the last `read` call.
        self._consume_leftover_channels_if_needed(60)

        results = [None for _ in range(len(self._input_channels))]

        from ray.dag import DAGContext

        ctx = DAGContext.get_current()
        iteration_timeout = ctx.read_iteration_timeout

        done_channels = set()
        while len(done_channels) < len(self._input_channels):
            for i, c in enumerate(self._input_channels):
                if c in done_channels:
                    continue
                try:
                    result = c.read(iteration_timeout)
                    results[i] = result
                    done_channels.add(c)
                    if isinstance(result, ray.exceptions.RayTaskError):
                        # Fail fast: remember the unread channels so they are
                        # drained before the next read, and return the error
                        # for every output.
                        self._leftover_channels = [
                            ch for ch in self._input_channels if ch not in done_channels
                        ]
                        return [result for _ in range(len(self._input_channels))]
                except ray.exceptions.RayChannelTimeoutError:
                    # Short per-iteration timeouts are expected; keep polling.
                    pass
                if sys.is_finalizing():
                    # The interpreter is exiting: return so the executor
                    # thread can be joined.
                    return results
        return results

    async def run(self):
        """
        Background loop. Each iteration pairs one `_run()` result with the next
        future from the queue and fulfills that future.

        Because `gather` uses `return_exceptions=True`, an exception raised by
        `_run()` is delivered as the future's result, not raised here.
        """
        loop = asyncio.get_running_loop()
        while not self._closed:
            res, fut = await asyncio.gather(
                loop.run_in_executor(self._background_task_executor, self._run),
                self._fut_queue.get(),
                return_exceptions=True,
            )

            # Set the result on the main thread. The awaiter may have cancelled
            # the future while we were blocked reading. In that case
            # set_result() would raise InvalidStateError and kill this loop,
            # leaving every later future unfulfilled, so drop the value instead.
            if not fut.done():
                fut.set_result(res)
            # NOTE(swang): If the object is zero-copy deserialized, then it
            # will stay in scope as long as ret and the future are in scope.
            # Therefore, we must delete both here after fulfilling the future.
            del res
            del fut

    def close(self):
        """Close the channels, stop the executor, and cancel the background task."""
        super().close()
        self._background_task_executor.shutdown(cancel_futures=True)
        # `start()` may never have been called, so the task may not exist.
        if self._background_task is not None:
            self._background_task.cancel()


@DeveloperAPI
class WriterInterface:
    """
    Base class for objects that write a value to a fixed list of output
    channels. Subclasses implement `start()` and `write()`.
    """

    def __init__(
        self,
        output_channels: List[ChannelInterface],
        output_idxs: List[Optional[Union[int, str]]],
        is_input=False,
    ):
        """
        Initialize the writer.

        Args:
            output_channels: The output channels to write to.
            output_idxs: The indices of the values to write to each channel.
                This list is aligned one-to-one with `output_channels`. If
                `is_input` is True, each index is an int or str. It selects the
                matching positional or keyword value from the DAG's input
                (`args` or `kwargs`). If `is_input` is False, an index of None
                writes the entire value. Any other index selects that element
                of the tuple value.
            is_input: Whether the writer is DAG input writer or not.
        """
        assert len(output_channels) == len(output_idxs)
        self._output_channels = output_channels
        self._output_idxs = output_idxs
        self._is_input = is_input
        # Mutable state: set by close() and incremented by subclasses' write().
        self._closed = False
        self._num_writes = 0

    def get_num_writes(self) -> int:
        """Return the number of writes this writer has recorded."""
        return self._num_writes

    def start(self):
        """Prepare the writer for use. Must be implemented by subclasses."""
        raise NotImplementedError()

    def write(self, val: Any, timeout: Optional[float] = None) -> None:
        """
        Write the value.

        Args:
            val: The value to write.
            timeout: The maximum time in seconds to wait for writing. 0 means
                immediate timeout (immediate success or timeout without blocking).
                -1 and None mean infinite timeout (blocks indefinitely).
        """
        raise NotImplementedError()

    def close(self) -> None:
        """Mark the writer closed and close every output channel."""
        self._closed = True
        for output_channel in self._output_channels:
            output_channel.close()


def _adapt(raw_args: Any, key: Optional[Union[int, str]], is_input: bool):
    """
    Adapt the raw arguments to the key. If `is_input` is True, this method will
    retrieve the value from the input data for an InputAttributeNode. Otherwise, it
    will retrieve either a partial value or the entire value from the output of
    a ClassMethodNode.

    Args:
        raw_args: The raw arguments to adapt.
        key: The key to adapt.
        is_input: Whether the writer is DAG input writer or not.
    """
    if is_input:
        if not isinstance(raw_args, CompiledDAGArgs):
            # Fast path for a single input.
            return raw_args
        else:
            args = raw_args.args
            kwargs = raw_args.kwargs

        if isinstance(key, int):
            return args[key]
        else:
            return kwargs[key]
    else:
        if key is not None:
            return raw_args[key]
        else:
            return raw_args


@DeveloperAPI
class SynchronousWriter(WriterInterface):
    """Writer that performs the channel writes inline, on the calling thread."""

    def start(self):
        """Register this process as the writer of every output channel."""
        for channel in self._output_channels:
            channel.ensure_registered_as_writer()

    def write(self, val: Any, timeout: Optional[float] = None) -> None:
        """
        Write `val` to every output channel, blocking until all writes finish.

        Args:
            val: The value to write.
                - For a DAG input writer: a single input value or a
                  CompiledDAGArgs.
                - Otherwise: the task's output. With more than one output
                  channel, it must be a tuple with one element per channel.
            timeout: Passed unchanged to each channel's `write()`. Each channel
                gets the full timeout; it is not a shared budget.

        Raises:
            ValueError: If a non-input writer with several output channels gets
                a value that is not a tuple of the expected length.
        """
        if not self._is_input:
            num_channels = len(self._output_channels)
            if num_channels > 1:
                # If it is an exception, there's only 1 return value.
                # We have to send the same data to all channels.
                # This is done only for non-input writers. An input writer
                # passes a single (non-CompiledDAGArgs) value to every channel
                # unchanged, so wrapping it would deliver the whole tuple of
                # exceptions to each channel.
                if isinstance(val, Exception):
                    val = tuple(val for _ in range(num_channels))
                if not isinstance(val, tuple):
                    raise ValueError(
                        f"Expected a tuple of {len(self._output_channels)} outputs, "
                        f"but got {type(val)}"
                    )
                if len(val) != len(self._output_channels):
                    raise ValueError(
                        f"Expected {len(self._output_channels)} outputs, but got "
                        f"{len(val)} outputs"
                    )

        for channel, idx in zip(self._output_channels, self._output_idxs):
            channel.write(_adapt(val, idx, self._is_input), timeout)
        self._num_writes += 1


@DeveloperAPI
class AwaitableBackgroundWriter(WriterInterface):
    """
    Asyncio-compatible channel writer.

    `write()` puts values on an asyncio queue. A background task takes them off
    the queue and performs the blocking channel writes on a single-thread
    executor.
    """

    def __init__(
        self,
        output_channels: List[ChannelInterface],
        output_idxs: List[Optional[Union[int, str]]],
        is_input=False,
    ):
        super().__init__(output_channels, output_idxs, is_input=is_input)
        self._queue = asyncio.Queue()
        self._background_task = None
        self._background_task_executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=1, thread_name_prefix="channel.AwaitableBackgroundWriter"
        )

    def start(self):
        """Register as writer on every channel and start the background task."""
        for channel in self._output_channels:
            channel.ensure_registered_as_writer()
        self._background_task = asyncio.ensure_future(self.run())

    def _run(self, res):
        """
        Write `res` to every output channel. Runs on the executor thread.

        Each write uses a 1-second timeout. It is retried until it succeeds or
        the interpreter starts exiting; in the latter case, the remaining
        channels are skipped.

        Raises:
            ValueError: If a non-input writer with several output channels gets
                a value that is not a tuple of the expected length.
        """
        if not self._is_input:
            num_channels = len(self._output_channels)
            if num_channels > 1:
                # Match SynchronousWriter: an exception is the only return
                # value, so send the same exception to every output channel
                # instead of failing the tuple check below.
                if isinstance(res, Exception):
                    res = tuple(res for _ in range(num_channels))
                if not isinstance(res, tuple):
                    raise ValueError(
                        f"Expected a tuple of {len(self._output_channels)} outputs, "
                        f"but got {type(res)}"
                    )
                if len(res) != len(self._output_channels):
                    raise ValueError(
                        f"Expected {len(self._output_channels)} outputs, but got "
                        f"{len(res)} outputs"
                    )

        for channel, idx in zip(self._output_channels, self._output_idxs):
            res_i = _adapt(res, idx, self._is_input)
            # Bind the loop variables as defaults so the callable never depends
            # on late binding of `channel` / `res_i`.
            exiting = retry_and_check_interpreter_exit(
                lambda channel=channel, res_i=res_i: channel.write(res_i, timeout=1)
            )
            if exiting:
                break

    async def run(self):
        """Background loop: take queued values and write them on the executor."""
        loop = asyncio.get_running_loop()
        while True:
            res = await self._queue.get()
            await loop.run_in_executor(self._background_task_executor, self._run, res)

    async def write(self, val: Any) -> None:
        """
        Queue `val` for writing by the background task.

        Raises:
            RuntimeError: If the writer has been closed.
        """
        if self._closed:
            raise RuntimeError("DAG execution cancelled")
        await self._queue.put(val)
        self._num_writes += 1

    def close(self):
        """Cancel the background task, close the channels, and stop the executor."""
        # `start()` may never have been called, so the task may not exist.
        if self._background_task is not None:
            self._background_task.cancel()
        super().close()
        # Release the executor's worker thread, as AwaitableBackgroundReader
        # does. Don't block close() on a write that may still be in flight.
        self._background_task_executor.shutdown(wait=False, cancel_futures=True)
