import asyncio
import logging
import random
from collections import deque
from typing import List, Tuple

import grpc
from grpc import aio as aiogrpc

import ray._private.gcs_utils as gcs_utils
from ray._common.utils import get_or_create_event_loop
from ray.core.generated import (
    gcs_pb2,
    gcs_service_pb2,
    gcs_service_pb2_grpc,
    pubsub_pb2,
)

logger = logging.getLogger(__name__)


class _SubscriberBase:
    def __init__(self, worker_id: bytes = None):
        self._worker_id = worker_id
        # self._subscriber_id needs to match the binary format of a random
        # SubscriberID / UniqueID, which is 28 (kUniqueIDSize) random bytes.
        self._subscriber_id = bytes(bytearray(random.getrandbits(8) for _ in range(28)))
        self._last_batch_size = 0
        self._max_processed_sequence_id = 0
        self._publisher_id = b""

    # Batch size of the result from last poll. Used to indicate whether the
    # subscriber can keep up.
    @property
    def last_batch_size(self):
        return self._last_batch_size

    def _subscribe_request(self, channel):
        cmd = pubsub_pb2.Command(channel_type=channel, subscribe_message={})
        req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
            subscriber_id=self._subscriber_id, sender_id=self._worker_id, commands=[cmd]
        )
        return req

    def _poll_request(self):
        return gcs_service_pb2.GcsSubscriberPollRequest(
            subscriber_id=self._subscriber_id,
            max_processed_sequence_id=self._max_processed_sequence_id,
            publisher_id=self._publisher_id,
        )

    def _unsubscribe_request(self, channels):
        req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
            subscriber_id=self._subscriber_id, sender_id=self._worker_id, commands=[]
        )
        for channel in channels:
            req.commands.append(
                pubsub_pb2.Command(channel_type=channel, unsubscribe_message={})
            )
        return req

    @staticmethod
    def _should_terminate_polling(e: grpc.RpcError) -> None:
        # Caller only expects polling to be terminated after deadline exceeded.
        if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
            return True
        # Could be a temporary connection issue. Suppress error.
        # TODO: reconnect GRPC channel?
        if e.code() == grpc.StatusCode.UNAVAILABLE:
            return True
        return False


class _AioSubscriber(_SubscriberBase):
    """Async io subscriber to GCS.

    Usage example common to Aio subscribers:
        subscriber = GcsAioXxxSubscriber(address="...")
        await subscriber.subscribe()
        while running:
            ...... = await subscriber.poll()
            ......
        await subscriber.close()
    """

    def __init__(
        self,
        pubsub_channel_type,
        worker_id: bytes = None,
        address: str = None,
        channel: aiogrpc.Channel = None,
    ):
        super().__init__(worker_id)

        if address:
            assert channel is None, "address and channel cannot both be specified"
            channel = gcs_utils.create_gcs_channel(address, aio=True)
        else:
            assert channel is not None, "One of address and channel must be specified"
        # GRPC stub to GCS pubsub.
        self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)

        # Type of the channel.
        self._channel = pubsub_channel_type
        # A queue of received PubMessage.
        self._queue = deque()
        # Indicates whether the subscriber has closed.
        self._close = asyncio.Event()

    async def subscribe(self) -> None:
        """Registers a subscription for the subscriber's channel type.

        Before the registration, published messages in the channel will not be
        saved for the subscriber.
        """
        if self._close.is_set():
            return
        req = self._subscribe_request(self._channel)
        await self._stub.GcsSubscriberCommandBatch(req, timeout=30)

    async def _poll_call(self, req, timeout=None):
        # Wrap GRPC _AioCall as a coroutine.
        return await self._stub.GcsSubscriberPoll(req, timeout=timeout)

    async def _poll(self, timeout=None) -> None:
        while len(self._queue) == 0:
            req = self._poll_request()
            poll = get_or_create_event_loop().create_task(
                self._poll_call(req, timeout=timeout)
            )
            close = get_or_create_event_loop().create_task(self._close.wait())
            done, others = await asyncio.wait(
                [poll, close], timeout=timeout, return_when=asyncio.FIRST_COMPLETED
            )
            # Cancel the other task if needed to prevent memory leak.
            other_task = others.pop()
            if not other_task.done():
                other_task.cancel()
            if poll not in done or close in done:
                # Request timed out or subscriber closed.
                break
            try:
                self._last_batch_size = len(poll.result().pub_messages)
                if poll.result().publisher_id != self._publisher_id:
                    if self._publisher_id != "":
                        logger.debug(
                            f"replied publisher_id {poll.result().publisher_id}"
                            f"different from {self._publisher_id}, this should "
                            "only happens during gcs failover."
                        )
                    self._publisher_id = poll.result().publisher_id
                    self._max_processed_sequence_id = 0
                for msg in poll.result().pub_messages:
                    if msg.sequence_id <= self._max_processed_sequence_id:
                        logger.warning(f"Ignoring out of order message {msg}")
                        continue
                    self._max_processed_sequence_id = msg.sequence_id
                    self._queue.append(msg)
            except grpc.RpcError as e:
                if self._should_terminate_polling(e):
                    return
                raise

    async def close(self) -> None:
        """Closes the subscriber and its active subscription."""

        # Mark close to terminate inflight polling and prevent future requests.
        if self._close.is_set():
            return
        self._close.set()
        req = self._unsubscribe_request(channels=[self._channel])
        try:
            await self._stub.GcsSubscriberCommandBatch(req, timeout=5)
        except Exception:
            pass
        self._stub = None


class GcsAioResourceUsageSubscriber(_AioSubscriber):
    def __init__(
        self,
        worker_id: bytes = None,
        address: str = None,
        channel: grpc.Channel = None,
    ):
        super().__init__(
            pubsub_pb2.RAY_NODE_RESOURCE_USAGE_CHANNEL, worker_id, address, channel
        )

    async def poll(self, timeout=None) -> Tuple[bytes, str]:
        """Polls for new resource usage message.

        Returns:
            A tuple of string reporter ID and resource usage json string.
        """
        await self._poll(timeout=timeout)
        return self._pop_resource_usage(self._queue)

    @staticmethod
    def _pop_resource_usage(queue):
        if len(queue) == 0:
            return None, None
        msg = queue.popleft()
        return msg.key_id.decode(), msg.node_resource_usage_message.json


class GcsAioActorSubscriber(_AioSubscriber):
    def __init__(
        self,
        worker_id: bytes = None,
        address: str = None,
        channel: grpc.Channel = None,
    ):
        super().__init__(pubsub_pb2.GCS_ACTOR_CHANNEL, worker_id, address, channel)

    @property
    def queue_size(self):
        return len(self._queue)

    async def poll(
        self, batch_size, timeout=None
    ) -> List[Tuple[bytes, gcs_pb2.ActorTableData]]:
        """Polls for new actor message.

        Returns:
            A list of tuples of binary actor ID and actor table data.
        """
        await self._poll(timeout=timeout)
        return self._pop_actors(self._queue, batch_size=batch_size)

    @staticmethod
    def _pop_actors(queue, batch_size):
        if len(queue) == 0:
            return []
        popped = 0
        msgs = []
        while len(queue) > 0 and popped < batch_size:
            msg = queue.popleft()
            msgs.append((msg.key_id, msg.actor_message))
            popped += 1
        return msgs


class GcsAioNodeInfoSubscriber(_AioSubscriber):
    def __init__(
        self,
        worker_id: bytes = None,
        address: str = None,
        channel: grpc.Channel = None,
    ):
        super().__init__(pubsub_pb2.GCS_NODE_INFO_CHANNEL, worker_id, address, channel)

    async def poll(
        self, batch_size, timeout=None
    ) -> List[Tuple[bytes, gcs_pb2.GcsNodeInfo]]:
        """Polls for new node info message.

        Returns:
            A list of tuples of (node_id, GcsNodeInfo).
        """
        await self._poll(timeout=timeout)
        return self._pop_node_infos(self._queue, batch_size=batch_size)

    @staticmethod
    def _pop_node_infos(queue, batch_size):
        if len(queue) == 0:
            return []
        popped = 0
        msgs = []
        while len(queue) > 0 and popped < batch_size:
            msg = queue.popleft()
            msgs.append((msg.key_id, msg.node_info_message))
            popped += 1
        return msgs
