import logging
import sys
import time
from collections import defaultdict
from queue import Queue
from threading import Event, Lock, Thread
from typing import TYPE_CHECKING, Any, Dict, Iterator, Union

import grpc

import ray
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray._private.client_mode_hook import disable_client_hook
from ray.util.client.common import (
    CLIENT_SERVER_MAX_THREADS,
    OrderedResponseCache,
    _propagate_error_in_context,
)
from ray.util.client.server.server_pickler import loads_from_client
from ray.util.debug import log_once

if TYPE_CHECKING:
    from ray.util.client.server.server import RayletServicer

logger = logging.getLogger(__name__)

QUEUE_JOIN_SECONDS = 10


def _get_reconnecting_from_context(context: Any) -> bool:
    """
    Get `reconnecting` from gRPC metadata, or False if missing.
    """
    metadata = dict(context.invocation_metadata())
    val = metadata.get("reconnecting")
    if val is None or val not in ("True", "False"):
        logger.error(
            f'Client connecting with invalid value for "reconnecting": {val}, '
            "This may be because you have a mismatched client and server "
            "version."
        )
        return False
    return val == "True"


def _should_cache(req: ray_client_pb2.DataRequest) -> bool:
    """
    Returns True if the response should to the given request should be cached,
    false otherwise. At the moment the only requests we do not cache are:
        - asynchronous gets: These arrive out of order. Skipping caching here
            is fine, since repeating an async get is idempotent
        - acks: Repeating acks is idempotent
        - clean up requests: Also idempotent, and client has likely already
             wrapped up the data connection by this point.
        - puts: We should only cache when we receive the final chunk, since
             any earlier chunks won't generate a response
        - tasks: We should only cache when we receive the final chunk,
             since any earlier chunks won't generate a response
    """
    req_type = req.WhichOneof("type")
    if req_type == "get" and req.get.asynchronous:
        return False
    if req_type == "put":
        return req.put.chunk_id == req.put.total_chunks - 1
    if req_type == "task":
        return req.task.chunk_id == req.task.total_chunks - 1
    return req_type not in ("acknowledge", "connection_cleanup")


def fill_queue(
    grpc_input_generator: Iterator[ray_client_pb2.DataRequest],
    output_queue: "Queue[Union[ray_client_pb2.DataRequest, ray_client_pb2.DataResponse]]",  # noqa: E501
) -> None:
    """
    Pushes incoming requests to a shared output_queue.
    """
    try:
        for req in grpc_input_generator:
            output_queue.put(req)
    except grpc.RpcError as e:
        logger.debug(
            "closing dataservicer reader thread "
            f"grpc error reading request_iterator: {e}"
        )
    finally:
        # Set the sentinel value for the output_queue
        output_queue.put(None)


class ChunkCollector:
    """
    Helper class for collecting chunks from PutObject or ClientTask messages
    """

    def __init__(self):
        self.curr_req_id = None
        self.last_seen_chunk_id = -1
        self.data = bytearray()

    def add_chunk(
        self,
        req: ray_client_pb2.DataRequest,
        chunk: Union[ray_client_pb2.PutRequest, ray_client_pb2.ClientTask],
    ):
        if self.curr_req_id is not None and self.curr_req_id != req.req_id:
            raise RuntimeError(
                "Expected to receive a chunk from request with id "
                f"{self.curr_req_id}, but found {req.req_id} instead."
            )
        self.curr_req_id = req.req_id
        next_chunk = self.last_seen_chunk_id + 1
        if chunk.chunk_id < next_chunk:
            # Repeated chunk, ignore
            return
        if chunk.chunk_id > next_chunk:
            raise RuntimeError(
                f"A chunk {chunk.chunk_id} of request {req.req_id} was "
                "received out of order."
            )
        elif chunk.chunk_id == self.last_seen_chunk_id + 1:
            self.data.extend(chunk.data)
            self.last_seen_chunk_id = chunk.chunk_id
        return chunk.chunk_id + 1 == chunk.total_chunks

    def reset(self):
        self.curr_req_id = None
        self.last_seen_chunk_id = -1
        self.data = bytearray()


class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
    def __init__(self, basic_service: "RayletServicer"):
        self.basic_service = basic_service
        self.clients_lock = Lock()
        self.num_clients = 0  # guarded by self.clients_lock
        # dictionary mapping client_id's to the last time they connected
        self.client_last_seen: Dict[str, float] = {}
        # dictionary mapping client_id's to their reconnect grace periods
        self.reconnect_grace_periods: Dict[str, float] = {}
        # dictionary mapping client_id's to their response cache
        self.response_caches: Dict[str, OrderedResponseCache] = defaultdict(
            OrderedResponseCache
        )
        # stopped event, useful for signals that the server is shut down
        self.stopped = Event()
        # Helper for collecting chunks from PutObject calls. Assumes that
        # that put requests from different objects aren't interleaved.
        self.put_request_chunk_collector = ChunkCollector()
        # Helper for collecting chunks from ClientTask calls. Assumes that
        # schedule requests from different remote calls aren't interleaved.
        self.client_task_chunk_collector = ChunkCollector()

    def Datapath(self, request_iterator, context):
        start_time = time.time()
        # set to True if client shuts down gracefully
        cleanup_requested = False
        metadata = dict(context.invocation_metadata())
        client_id = metadata.get("client_id")
        if client_id is None:
            logger.error("Client connecting with no client_id")
            return
        logger.debug(f"New data connection from client {client_id}: ")
        accepted_connection = self._init(client_id, context, start_time)
        response_cache = self.response_caches[client_id]
        # Set to False if client requests a reconnect grace period of 0
        reconnect_enabled = True
        if not accepted_connection:
            return
        try:
            request_queue = Queue()
            queue_filler_thread = Thread(
                target=fill_queue, daemon=True, args=(request_iterator, request_queue)
            )
            queue_filler_thread.start()
            """For non `async get` requests, this loop yields immediately
            For `async get` requests, this loop:
                 1) does not yield, it just continues
                 2) When the result is ready, it yields
            """
            for req in iter(request_queue.get, None):
                if isinstance(req, ray_client_pb2.DataResponse):
                    # Early shortcut if this is the result of an async get.
                    yield req
                    continue

                assert isinstance(req, ray_client_pb2.DataRequest)
                if _should_cache(req) and reconnect_enabled:
                    cached_resp = response_cache.check_cache(req.req_id)
                    if isinstance(cached_resp, Exception):
                        # Cache state is invalid, raise exception
                        raise cached_resp
                    if cached_resp is not None:
                        yield cached_resp
                        continue

                resp = None
                req_type = req.WhichOneof("type")
                if req_type == "init":
                    resp_init = self.basic_service.Init(req.init)
                    resp = ray_client_pb2.DataResponse(
                        init=resp_init,
                    )
                    with self.clients_lock:
                        self.reconnect_grace_periods[
                            client_id
                        ] = req.init.reconnect_grace_period
                        if req.init.reconnect_grace_period == 0:
                            reconnect_enabled = False

                elif req_type == "get":
                    if req.get.asynchronous:
                        get_resp = self.basic_service._async_get_object(
                            req.get, client_id, req.req_id, request_queue
                        )
                        if get_resp is None:
                            # Skip sending a response for this request and
                            # continue to the next requst. The response for
                            # this request will be sent when the object is
                            # ready.
                            continue
                    else:
                        get_resp = self.basic_service._get_object(req.get, client_id)
                    resp = ray_client_pb2.DataResponse(get=get_resp)
                elif req_type == "put":
                    if not self.put_request_chunk_collector.add_chunk(req, req.put):
                        # Put request still in progress
                        continue
                    put_resp = self.basic_service._put_object(
                        self.put_request_chunk_collector.data,
                        req.put.client_ref_id,
                        client_id,
                        req.put.owner_id,
                    )
                    self.put_request_chunk_collector.reset()
                    resp = ray_client_pb2.DataResponse(put=put_resp)
                elif req_type == "release":
                    released = []
                    for rel_id in req.release.ids:
                        rel = self.basic_service.release(client_id, rel_id)
                        released.append(rel)
                    resp = ray_client_pb2.DataResponse(
                        release=ray_client_pb2.ReleaseResponse(ok=released)
                    )
                elif req_type == "connection_info":
                    resp = ray_client_pb2.DataResponse(
                        connection_info=self._build_connection_response()
                    )
                elif req_type == "prep_runtime_env":
                    with self.clients_lock:
                        resp_prep = self.basic_service.PrepRuntimeEnv(
                            req.prep_runtime_env
                        )
                        resp = ray_client_pb2.DataResponse(prep_runtime_env=resp_prep)
                elif req_type == "connection_cleanup":
                    cleanup_requested = True
                    cleanup_resp = ray_client_pb2.ConnectionCleanupResponse()
                    resp = ray_client_pb2.DataResponse(connection_cleanup=cleanup_resp)
                elif req_type == "acknowledge":
                    # Clean up acknowledged cache entries
                    response_cache.cleanup(req.acknowledge.req_id)
                    continue
                elif req_type == "task":
                    with self.clients_lock:
                        task = req.task
                        if not self.client_task_chunk_collector.add_chunk(req, task):
                            # Not all serialized arguments have arrived
                            continue
                        arglist, kwargs = loads_from_client(
                            self.client_task_chunk_collector.data, self.basic_service
                        )
                        self.client_task_chunk_collector.reset()
                        resp_ticket = self.basic_service.Schedule(
                            req.task, arglist, kwargs, context
                        )
                        resp = ray_client_pb2.DataResponse(task_ticket=resp_ticket)
                        del arglist
                        del kwargs
                elif req_type == "terminate":
                    with self.clients_lock:
                        response = self.basic_service.Terminate(req.terminate, context)
                        resp = ray_client_pb2.DataResponse(terminate=response)
                elif req_type == "list_named_actors":
                    with self.clients_lock:
                        response = self.basic_service.ListNamedActors(
                            req.list_named_actors
                        )
                        resp = ray_client_pb2.DataResponse(list_named_actors=response)
                else:
                    raise Exception(
                        f"Unreachable code: Request type "
                        f"{req_type} not handled in Datapath"
                    )
                resp.req_id = req.req_id
                if _should_cache(req) and reconnect_enabled:
                    response_cache.update_cache(req.req_id, resp)
                yield resp
        except Exception as e:
            logger.exception("Error in data channel:")
            recoverable = _propagate_error_in_context(e, context)
            invalid_cache = response_cache.invalidate(e)
            if not recoverable or invalid_cache:
                context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
                # Connection isn't recoverable, skip cleanup
                cleanup_requested = True
        finally:
            logger.debug(f"Stream is broken with client {client_id}")
            queue_filler_thread.join(QUEUE_JOIN_SECONDS)
            if queue_filler_thread.is_alive():
                logger.error(
                    "Queue filler thread failed to join before timeout: {}".format(
                        QUEUE_JOIN_SECONDS
                    )
                )
            cleanup_delay = self.reconnect_grace_periods.get(client_id)
            if not cleanup_requested and cleanup_delay is not None:
                logger.debug(
                    "Cleanup wasn't requested, delaying cleanup by"
                    f"{cleanup_delay} seconds."
                )
                # Delay cleanup, since client may attempt a reconnect
                # Wait on the "stopped" event in case the grpc server is
                # stopped and we can clean up earlier.
                self.stopped.wait(timeout=cleanup_delay)
            else:
                logger.debug("Cleanup was requested, cleaning up immediately.")
            with self.clients_lock:
                if client_id not in self.client_last_seen:
                    logger.debug("Connection already cleaned up.")
                    # Some other connection has already cleaned up this
                    # this client's session. This can happen if the client
                    # reconnects and then gracefully shut's down immediately.
                    return
                last_seen = self.client_last_seen[client_id]
                if last_seen > start_time:
                    # The client successfully reconnected and updated
                    # last seen some time during the grace period
                    logger.debug("Client reconnected, skipping cleanup")
                    return
                # Either the client shut down gracefully, or the client
                # failed to reconnect within the grace period. Clean up
                # the connection.
                self.basic_service.release_all(client_id)
                del self.client_last_seen[client_id]
                if client_id in self.reconnect_grace_periods:
                    del self.reconnect_grace_periods[client_id]
                if client_id in self.response_caches:
                    del self.response_caches[client_id]
                self.num_clients -= 1
                logger.debug(
                    f"Removed client {client_id}, " f"remaining={self.num_clients}"
                )

                # It's important to keep the Ray shutdown
                # within this locked context or else Ray could hang.
                # NOTE: it is strange to start ray in server.py but shut it
                # down here. Consider consolidating ray lifetime management.
                with disable_client_hook():
                    if self.num_clients == 0:
                        logger.debug("Shutting down ray.")
                        ray.shutdown()

    def _init(self, client_id: str, context: Any, start_time: float):
        """
        Checks if resources allow for another client.
        Returns a boolean indicating if initialization was successful.
        """
        with self.clients_lock:
            reconnecting = _get_reconnecting_from_context(context)
            threshold = int(CLIENT_SERVER_MAX_THREADS / 2)
            if self.num_clients >= threshold:
                logger.warning(
                    f"[Data Servicer]: Num clients {self.num_clients} "
                    f"has reached the threshold {threshold}. "
                    f"Rejecting client: {client_id}. "
                )
                if log_once("client_threshold"):
                    logger.warning(
                        "You can configure the client connection "
                        "threshold by setting the "
                        "RAY_CLIENT_SERVER_MAX_THREADS env var "
                        f"(currently set to {CLIENT_SERVER_MAX_THREADS})."
                    )
                context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED)
                return False
            if reconnecting and client_id not in self.client_last_seen:
                # Client took too long to reconnect, session has been
                # cleaned up.
                context.set_code(grpc.StatusCode.NOT_FOUND)
                context.set_details(
                    "Attempted to reconnect to a session that has already "
                    "been cleaned up."
                )
                return False
            if client_id in self.client_last_seen:
                logger.debug(f"Client {client_id} has reconnected.")
            else:
                self.num_clients += 1
                logger.debug(
                    f"Accepted data connection from {client_id}. "
                    f"Total clients: {self.num_clients}"
                )
            self.client_last_seen[client_id] = start_time
            return True

    def _build_connection_response(self):
        with self.clients_lock:
            cur_num_clients = self.num_clients
        return ray_client_pb2.ConnectionInfoResponse(
            num_clients=cur_num_clients,
            python_version="{}.{}.{}".format(
                sys.version_info[0], sys.version_info[1], sys.version_info[2]
            ),
            ray_version=ray.__version__,
            ray_commit=ray.__commit__,
        )
