import asyncio
import inspect
import logging
import random
import time
from collections.abc import Sequence
from functools import wraps
from typing import Callable, Dict, List, Optional, Tuple, Union

import ray
from ray.actor import ActorHandle
from ray.serve._private.application_state import StatusOverview
from ray.serve._private.build_app import BuiltApplication
from ray.serve._private.common import (
    DeploymentID,
    DeploymentStatus,
    DeploymentStatusInfo,
    RequestRoutingInfo,
)
from ray.serve._private.constants import (
    CLIENT_CHECK_CREATION_POLLING_INTERVAL_S,
    CLIENT_POLLING_INTERVAL_S,
    HTTP_PROXY_TIMEOUT,
    MAX_CACHED_HANDLES,
    SERVE_DEFAULT_APP_NAME,
    SERVE_LOGGER_NAME,
)
from ray.serve._private.controller import ServeController
from ray.serve._private.deploy_utils import get_deploy_args
from ray.serve._private.deployment_info import DeploymentInfo
from ray.serve._private.http_util import ASGIAppReplicaWrapper
from ray.serve._private.utils import get_random_string
from ray.serve.config import HTTPOptions
from ray.serve.exceptions import RayServeException
from ray.serve.generated.serve_pb2 import (
    ApplicationArgs,
    DeploymentArgs,
    DeploymentRoute,
    DeploymentStatusInfo as DeploymentStatusInfoProto,
    StatusOverview as StatusOverviewProto,
)
from ray.serve.handle import DeploymentHandle
from ray.serve.schema import (
    ApplicationStatus,
    LoggingConfig,
    ServeApplicationSchema,
    ServeDeploySchema,
)

# Module-level logger used throughout this file; it is looked up under the
# name given by SERVE_LOGGER_NAME.
logger = logging.getLogger(SERVE_LOGGER_NAME)


def _ensure_connected(f: Callable) -> Callable:
    """Decorate a client method so that it refuses to run after shutdown.

    On every call the wrapper inspects ``self._shutdown``. If it is truthy a
    ``RayServeException`` is raised; otherwise the call is forwarded to ``f``
    with the same arguments and its result is returned.
    """

    @wraps(f)
    def guarded(self, *call_args, **call_kwargs):
        if not self._shutdown:
            return f(self, *call_args, **call_kwargs)
        raise RayServeException("Client has already been shut down.")

    return guarded


class ServeControllerClient:
    def __init__(
        self,
        controller: ActorHandle,
    ):
        """Bind the client to a controller actor and fetch its settings.

        Args:
            controller: Handle to the Serve controller actor. Its HTTP config
                and root URL are fetched with blocking ``ray.get`` calls
                during construction.
        """
        self._shutdown = False
        self._controller: ServeController = controller

        http_config_ref = controller.get_http_config.remote()
        self._http_config: HTTPOptions = ray.get(http_config_ref)
        root_url_ref = controller.get_root_url.remote()
        self._root_url = ray.get(root_url_ref)

        # Handles are cached because each one carries the overhead of a long
        # poll client.
        self.handle_cache = {}
        self._evicted_handle_keys = set()

    @property
    def root_url(self):
        """Root URL that was fetched from the controller in ``__init__``."""
        return self._root_url

    @property
    def http_config(self):
        """HTTP options that were fetched from the controller in ``__init__``."""
        return self._http_config

    def __reduce__(self):
        """Refuse pickling: the client cannot be serialized.

        Raises:
            RayServeException: Always.
        """
        message = "Ray Serve client cannot be serialized."
        raise RayServeException(message)

    def shutdown_cached_handles(self):
        """Shut down every cached handle and drop it from the cache.

        Dropping the references allows the handles to be garbage collected.
        """
        # Iterate over a snapshot of the keys because entries are deleted
        # from the cache inside the loop.
        for key in tuple(self.handle_cache):
            handle = self.handle_cache[key]
            handle.shutdown()
            del self.handle_cache[key]

    async def shutdown_cached_handles_async(self):
        """Shut down every cached handle concurrently and drop it from the cache.

        Dropping the references allows the handles to be garbage collected.
        """

        async def _close_and_forget(key):
            # Each entry is removed only after its own shutdown has finished.
            handle = self.handle_cache[key]
            await handle.shutdown_async()
            del self.handle_cache[key]

        # Snapshot the keys first; the coroutines mutate the cache.
        pending = [_close_and_forget(key) for key in tuple(self.handle_cache)]
        await asyncio.gather(*pending)

    def shutdown(self, timeout_s: float = 30.0) -> None:
        """Completely shut down the connected Serve instance.

        Shuts down all processes and deletes all state associated with the
        instance. Cached handles are always shut down first; the controller is
        only contacted when Ray is initialized and this client has not been
        shut down before.

        Args:
            timeout_s: Maximum number of seconds to wait for the controller's
                graceful shutdown before logging a warning and giving up.
        """
        self.shutdown_cached_handles()

        if not ray.is_initialized() or self._shutdown:
            return

        try:
            shutdown_ref = self._controller.graceful_shutdown.remote()
            ray.get(shutdown_ref, timeout=timeout_s)
        except ray.exceptions.RayActorError:
            # Controller has been shut down.
            pass
        except TimeoutError:
            logger.warning(
                f"Controller failed to shut down within {timeout_s}s. "
                "Check controller logs for more details."
            )

        # Reached for a clean shutdown, a dead controller and a timeout alike.
        self._shutdown = True

    async def shutdown_async(self, timeout_s: float = 30.0) -> None:
        """Completely shut down the connected Serve instance.

        Shuts down all processes and deletes all state associated with the
        instance. Cached handles are always shut down first; the controller is
        only contacted when Ray is initialized and this client has not been
        shut down before.

        Args:
            timeout_s: Maximum number of seconds to wait for the controller's
                graceful shutdown before logging a warning and giving up.
        """
        await self.shutdown_cached_handles_async()

        if ray.is_initialized() and not self._shutdown:
            try:
                await asyncio.wait_for(
                    self._controller.graceful_shutdown.remote(), timeout=timeout_s
                )
            except ray.exceptions.RayActorError:
                # Controller has been shut down.
                pass
            except (TimeoutError, asyncio.TimeoutError):
                # Before Python 3.11, asyncio.wait_for raises
                # asyncio.TimeoutError, which is unrelated to the builtin
                # TimeoutError. Catch both so a slow controller produces the
                # warning below instead of an unhandled exception.
                logger.warning(
                    f"Controller failed to shut down within {timeout_s}s. "
                    "Check controller logs for more details."
                )
            self._shutdown = True

    def _wait_for_deployment_healthy(self, name: str, timeout_s: int = -1):
        """Block until the deployment ``name`` reports the HEALTHY status.

        The controller is polled every CLIENT_POLLING_INTERVAL_S seconds. A
        negative ``timeout_s`` means the wait has no deadline.

        Raises:
            RuntimeError: if the controller has no status for the deployment,
                or the deployment enters the UNHEALTHY status.
            TimeoutError: if the deployment is not HEALTHY before timeout_s.
        """
        began = time.time()
        while True:
            # The deadline is checked before every poll, including the first.
            if not (time.time() - began < timeout_s or timeout_s < 0):
                raise TimeoutError(
                    f"Deployment {name} did not become HEALTHY after {timeout_s}s."
                )

            raw_status = ray.get(self._controller.get_deployment_status.remote(name))
            if raw_status is None:
                raise RuntimeError(
                    f"Waiting for deployment {name} to be HEALTHY, "
                    "but deployment doesn't exist."
                )

            proto = DeploymentStatusInfoProto.FromString(raw_status)
            status = DeploymentStatusInfo.from_proto(proto)

            if status.status == DeploymentStatus.HEALTHY:
                return
            if status.status == DeploymentStatus.UNHEALTHY:
                raise RuntimeError(
                    f"Deployment {name} is UNHEALTHY: " f"{status.message}"
                )

            # Guard against new unhandled statuses being added.
            assert status.status == DeploymentStatus.UPDATING

            logger.debug(
                f"Waiting for {name} to be healthy, current status: "
                f"{status.status}."
            )
            time.sleep(CLIENT_POLLING_INTERVAL_S)

    def _wait_for_deployment_deleted(
        self, name: str, app_name: str, timeout_s: int = 60
    ):
        """Waits for the named deployment to be shut down and deleted.

        The deployment counts as deleted as soon as the controller returns no
        status for it. The controller is polled every
        CLIENT_POLLING_INTERVAL_S seconds.

        Args:
            name: Deployment name.
            app_name: Name of the application the deployment belongs to.
            timeout_s: Maximum number of seconds to keep polling.

        Raises:
            TimeoutError: if the deployment still has a status after timeout_s.
        """
        start = time.time()
        while time.time() - start < timeout_s:
            # Pass app_name along so the lookup targets the deployment inside
            # its application, matching _wait_for_deployment_created.
            curr_status_bytes = ray.get(
                self._controller.get_deployment_status.remote(name, app_name)
            )
            if curr_status_bytes is None:
                break
            curr_status = DeploymentStatusInfo.from_proto(
                DeploymentStatusInfoProto.FromString(curr_status_bytes)
            )
            logger.debug(
                f"Waiting for {name} to be deleted, current status: {curr_status}."
            )
            time.sleep(CLIENT_POLLING_INTERVAL_S)
        else:
            # Only reached when the loop ran out of time without a break.
            raise TimeoutError(f"Deployment {name} wasn't deleted after {timeout_s}s.")

    def _wait_for_deployment_created(
        self, deployment_name: str, app_name: str, timeout_s: int = -1
    ):
        """Waits for the named deployment to be created.

        A deployment being created simply means that it has been registered
        with the deployment state manager. The deployment state manager
        will then continue to reconcile the deployment towards its
        target state.

        Here "created" is detected as the controller returning a status for
        the deployment. The controller is polled every
        CLIENT_CHECK_CREATION_POLLING_INTERVAL_S seconds.

        Args:
            deployment_name: Deployment name.
            app_name: Name of the application the deployment belongs to.
            timeout_s: Maximum number of seconds to keep polling; a negative
                value waits without a deadline.

        Raises:
            TimeoutError: if this doesn't happen before timeout_s.
        """

        start = time.time()
        while time.time() - start < timeout_s or timeout_s < 0:
            status_bytes = ray.get(
                self._controller.get_deployment_status.remote(deployment_name, app_name)
            )

            if status_bytes is not None:
                break

            logger.debug(
                f"Waiting for deployment '{deployment_name}' in application "
                f"'{app_name}' to be created."
            )
            time.sleep(CLIENT_CHECK_CREATION_POLLING_INTERVAL_S)
        else:
            # This method waits for creation, not for health, so the message
            # must say so.
            raise TimeoutError(
                f"Deployment '{deployment_name}' in application '{app_name}' "
                f"was not created after {timeout_s}s."
            )

    def _wait_for_application_running(self, name: str, timeout_s: int = -1):
        """Block until the application ``name`` reports the RUNNING status.

        The controller is polled every CLIENT_POLLING_INTERVAL_S seconds. A
        negative ``timeout_s`` means the wait has no deadline.

        Raises:
            RuntimeError: if the controller has no status for the application,
                or the application enters the "DEPLOY_FAILED" status.
            TimeoutError: if the application is not RUNNING before timeout_s.
        """
        began = time.time()
        while True:
            # The deadline is checked before every poll, including the first.
            if not (time.time() - began < timeout_s or timeout_s < 0):
                raise TimeoutError(
                    f"Application {name} did not become RUNNING after {timeout_s}s."
                )

            raw_status = ray.get(self._controller.get_serve_status.remote(name))
            if raw_status is None:
                raise RuntimeError(
                    f"Waiting for application {name} to be RUNNING, "
                    "but application doesn't exist."
                )

            proto = StatusOverviewProto.FromString(raw_status)
            status = StatusOverview.from_proto(proto)

            if status.app_status.status == ApplicationStatus.RUNNING:
                return
            if status.app_status.status == ApplicationStatus.DEPLOY_FAILED:
                raise RuntimeError(
                    f"Deploying application {name} failed: {status.app_status.message}"
                )

            logger.debug(
                f"Waiting for {name} to be RUNNING, current status: "
                f"{status.app_status.status}."
            )
            time.sleep(CLIENT_POLLING_INTERVAL_S)

    @_ensure_connected
    def wait_for_proxies_serving(
        self, wait_for_applications_running: bool = True
    ) -> None:
        """Wait for the proxies to be ready to serve requests.

        Args:
            wait_for_applications_running: Forwarded unchanged to each proxy's
                ``serving`` call.

        Raises:
            TimeoutError: if any ``serving`` call is still pending after
                HTTP_PROXY_TIMEOUT seconds, or finishes with an error other
                than ``RayActorError``.
        """
        proxies = ray.get(self._controller.get_proxies.remote())

        refs = []
        for proxy in proxies.values():
            refs.append(
                proxy.serving.remote(
                    wait_for_applications_running=wait_for_applications_running
                )
            )

        finished, unfinished = ray.wait(
            refs,
            timeout=HTTP_PROXY_TIMEOUT,
            num_returns=len(refs),
        )
        if unfinished:
            raise TimeoutError(f"Proxies not available after {HTTP_PROXY_TIMEOUT}s.")

        # Ensure the proxies are either serving or dead.
        for ref in finished:
            try:
                ray.get(ref, timeout=1)
            except ray.exceptions.RayActorError:
                # A dead proxy is tolerated; move on to the next one.
                continue
            except Exception:
                raise TimeoutError(
                    f"Proxies not available after {HTTP_PROXY_TIMEOUT}s."
                )

    @_ensure_connected
    def deploy_applications(
        self,
        built_apps: Sequence[BuiltApplication],
        *,
        wait_for_ingress_deployment_creation: bool = True,
        wait_for_applications_running: bool = True,
    ) -> List[DeploymentHandle]:
        """Deploy built applications through the controller.

        Every deployment of every application is turned into a serialized
        ``DeploymentArgs`` proto, and every application into a serialized
        ``ApplicationArgs`` proto. Both are sent to the controller in a single
        blocking ``deploy_applications`` call.

        Args:
            built_apps: Applications to deploy.
            wait_for_ingress_deployment_creation: Whether to block until each
                application's ingress deployment has been created.
            wait_for_applications_running: Whether to block until each
                application is running.

        Returns:
            One handle per application, in the order of ``built_apps``, each
            pointing at that application's ingress deployment.
        """
        deployment_args_by_app = {}
        application_args_by_app = {}
        for built_app in built_apps:
            serialized_deployments = []
            for dep in built_app.deployments:
                # A deployment without its own logging config falls back to
                # the application-level one, if any.
                if dep.logging_config is None and built_app.logging_config:
                    dep = dep.options(logging_config=built_app.logging_config)

                ingress_flag = dep.name == built_app.ingress_deployment_name
                args = get_deploy_args(
                    dep.name,
                    ingress=ingress_flag,
                    replica_config=dep._replica_config,
                    deployment_config=dep._deployment_config,
                    version=dep._version or get_random_string(),
                    # Only the ingress deployment carries the route prefix.
                    route_prefix=built_app.route_prefix if ingress_flag else None,
                )

                proto = DeploymentArgs()
                proto.deployment_name = args["deployment_name"]
                proto.deployment_config = args["deployment_config_proto_bytes"]
                proto.replica_config = args["replica_config_proto_bytes"]
                proto.deployer_job_id = args["deployer_job_id"]
                route_prefix = args["route_prefix"]
                if route_prefix:
                    proto.route_prefix = route_prefix
                proto.ingress = args["ingress"]

                serialized_deployments.append(proto.SerializeToString())

            app_proto = ApplicationArgs()
            app_proto.external_scaler_enabled = built_app.external_scaler_enabled

            deployment_args_by_app[built_app.name] = serialized_deployments
            application_args_by_app[built_app.name] = app_proto.SerializeToString()

        # Validate applications before sending to controller
        self._check_ingress_deployments(built_apps)

        deploy_ref = self._controller.deploy_applications.remote(
            deployment_args_by_app, application_args_by_app
        )
        ray.get(deploy_ref)

        handles = []
        for built_app in built_apps:
            ingress_name = built_app.ingress_deployment_name

            # The deployment state is not guaranteed to be created after
            # deploy_applications returns; the application state manager will
            # need another reconcile iteration to create it.
            if wait_for_ingress_deployment_creation:
                self._wait_for_deployment_created(ingress_name, built_app.name)

            if wait_for_applications_running:
                self._wait_for_application_running(built_app.name)
                url_part = (
                    ""
                    if built_app.route_prefix is None
                    else " at " + self._root_url + built_app.route_prefix
                )
                logger.info(f"Application '{built_app.name}' is ready{url_part}.")

            handle = self.get_handle(ingress_name, built_app.name, check_exists=False)
            handles.append(handle)

        return handles

    @_ensure_connected
    def deploy_apps(
        self,
        config: Union[ServeApplicationSchema, ServeDeploySchema],
        _blocking: bool = False,
    ) -> None:
        """Starts a task on the controller that deploys application(s) from a config.

        When ``_blocking`` is set, the client polls the application statuses
        every CLIENT_POLLING_INTERVAL_S seconds for up to 60 seconds and then
        waits for the proxies to be serving.

        Args:
            config: A single-application config (ServeApplicationSchema) or a
                multi-application config (ServeDeploySchema)
            _blocking: Whether to block until the application is running.

        Raises:
            RayTaskError: If the deploy task on the controller fails. This can be
                because a single-app config was deployed after deploying a multi-app
                config, or vice versa.
            TimeoutError: If ``_blocking`` is set and the applications named in
                the config are not all RUNNING within 60 seconds.
        """
        ray.get(self._controller.apply_config.remote(config))

        if not _blocking:
            return

        timeout_s = 60

        if isinstance(config, ServeDeploySchema):
            app_names = {app.name for app in config.applications}
        else:
            app_names = {config.name}

        start = time.time()
        while time.time() - start < timeout_s:
            statuses = self.list_serve_statuses()
            app_to_status = {
                status.name: status.app_status.status
                for status in statuses
                if status.name in app_names
            }
            # Every requested app must be reported and RUNNING. all() is
            # vacuously true for a config that names no applications, so that
            # case finishes immediately instead of polling until the timeout.
            if len(app_to_status) == len(app_names) and all(
                app_status == ApplicationStatus.RUNNING
                for app_status in app_to_status.values()
            ):
                break

            time.sleep(CLIENT_POLLING_INTERVAL_S)
        else:
            raise TimeoutError(
                f"Serve application isn't running after {timeout_s}s."
            )

        self.wait_for_proxies_serving(wait_for_applications_running=True)

    def _check_ingress_deployments(
        self, built_apps: Sequence[BuiltApplication]
    ) -> None:
        """Check @serve.ingress of deployments across applications.

        A deployment is counted when its ``func_or_class`` is a class that
        subclasses ``ASGIAppReplicaWrapper``.

        Args:
            built_apps: Applications whose deployments are inspected.

        Raises:
            RayServeException: if more than one @serve.ingress is found among
                deployments in any single application.
        """
        for app in built_apps:
            # inspect.isclass guards issubclass, which raises TypeError when
            # its first argument is not a class.
            num_ingress_deployments = sum(
                1
                for deployment in app.deployments
                if inspect.isclass(deployment.func_or_class)
                and issubclass(deployment.func_or_class, ASGIAppReplicaWrapper)
            )

            if num_ingress_deployments > 1:
                raise RayServeException(
                    f'Found multiple FastAPI deployments in application "{app.name}". '
                    "Please only include one deployment with @serve.ingress "
                    "in your application to avoid this issue."
                )

    @_ensure_connected
    def delete_apps(self, names: List[str], blocking: bool = True):
        """Ask the controller to delete the named applications.

        Args:
            names: Application names; nothing happens for an empty list.
            blocking: Whether to poll, for up to 60 seconds, until every named
                application reports the NOT_STARTED status.

        Raises:
            TimeoutError: if ``blocking`` is set and some application has not
                reached NOT_STARTED within 60 seconds.
        """
        if not names:
            return

        logger.info(f"Deleting app {names}")
        # The delete request itself is not awaited; completion is observed
        # through the status polling below.
        self._controller.delete_apps.remote(names)
        if not blocking:
            return

        began = time.time()
        while time.time() - began < 60:
            raw_statuses = ray.get(self._controller.get_serve_statuses.remote(names))
            still_present = [
                StatusOverview.from_proto(
                    StatusOverviewProto.FromString(raw)
                ).app_status.status
                != ApplicationStatus.NOT_STARTED
                for raw in raw_statuses
            ]
            if not any(still_present):
                return
            time.sleep(CLIENT_POLLING_INTERVAL_S)

        raise TimeoutError(
            f"Some of these applications weren't deleted after 60s: {names}"
        )

    @_ensure_connected
    def delete_all_apps(self, blocking: bool = True):
        """Delete every application the controller currently lists.

        Args:
            blocking: Forwarded to ``delete_apps``.
        """
        raw_statuses = ray.get(self._controller.list_serve_statuses.remote())
        app_names = [
            StatusOverview.from_proto(StatusOverviewProto.FromString(raw)).name
            for raw in raw_statuses
        ]
        self.delete_apps(app_names, blocking)

    @_ensure_connected
    def get_deployment_info(
        self, name: str, app_name: str
    ) -> Tuple[DeploymentInfo, Optional[str]]:
        """Fetch a deployment's info and route from the controller.

        Args:
            name: Deployment name.
            app_name: Name of the application the deployment belongs to.

        Returns:
            A ``(deployment_info, route)`` tuple. ``route`` is ``None`` when
            the ``DeploymentRoute`` proto carries an empty route string.
        """
        deployment_route = DeploymentRoute.FromString(
            ray.get(self._controller.get_deployment_info.remote(name, app_name))
        )
        return (
            DeploymentInfo.from_proto(deployment_route.deployment_info),
            # An empty string in the proto stands for "no route".
            deployment_route.route if deployment_route.route != "" else None,
        )

    @_ensure_connected
    def get_serve_status(self, name: str = SERVE_DEFAULT_APP_NAME) -> StatusOverview:
        """Fetch and decode the status overview of the application ``name``."""
        raw_status = ray.get(self._controller.get_serve_status.remote(name))
        status_proto = StatusOverviewProto.FromString(raw_status)
        return StatusOverview.from_proto(status_proto)

    @_ensure_connected
    def list_serve_statuses(self) -> List[StatusOverview]:
        """Fetch and decode the status overviews listed by the controller."""
        overviews = []
        for raw in ray.get(self._controller.list_serve_statuses.remote()):
            proto = StatusOverviewProto.FromString(raw)
            overviews.append(StatusOverview.from_proto(proto))
        return overviews

    @_ensure_connected
    def get_all_deployment_statuses(self) -> List[DeploymentStatusInfo]:
        """Fetch and decode the deployment statuses listed by the controller."""
        infos = []
        for raw in ray.get(self._controller.get_all_deployment_statuses.remote()):
            proto = DeploymentStatusInfoProto.FromString(raw)
            infos.append(DeploymentStatusInfo.from_proto(proto))
        return infos

    @_ensure_connected
    def get_serve_details(self) -> Dict:
        """Return the controller's Serve instance details, fetched synchronously."""
        details_ref = self._controller.get_serve_instance_details.remote()
        return ray.get(details_ref)

    @_ensure_connected
    def get_handle(
        self,
        deployment_name: str,
        app_name: Optional[str] = SERVE_DEFAULT_APP_NAME,
        check_exists: bool = True,
    ) -> DeploymentHandle:
        """Construct a handle for the specified deployment.

        Handles are cached per ``(deployment_name, app_name, check_exists)``;
        a cached handle is returned without contacting the controller.

        Args:
            deployment_name: Deployment name.
            app_name: Application name.
            check_exists: If False, then Serve won't check the deployment
                is registered. True by default.

        Returns:
            DeploymentHandle

        Raises:
            KeyError: if ``check_exists`` is set and the controller does not
                list the deployment.
        """
        deployment_id = DeploymentID(name=deployment_name, app_name=app_name)
        key = (deployment_name, app_name, check_exists)
        try:
            return self.handle_cache[key]
        except KeyError:
            # Cache miss: fall through and build a new handle.
            pass

        if check_exists:
            known_ids = ray.get(self._controller.list_deployment_ids.remote())
            if deployment_id not in known_ids:
                raise KeyError(f"{deployment_id} does not exist.")

        new_handle = DeploymentHandle(deployment_name, app_name)
        self.handle_cache[key] = new_handle

        if key in self._evicted_handle_keys:
            logger.warning(
                "You just got a ServeHandle that was evicted from internal "
                "cache. This means you are getting too many ServeHandles in "
                "the same process, this will bring down Serve's performance. "
                "Please post a github issue at "
                "https://github.com/ray-project/ray/issues to let the Serve "
                "team to find workaround for your use case."
            )

        if len(self.handle_cache) > MAX_CACHED_HANDLES:
            # Perform random eviction to keep the handle cache from growing
            # infinitely. We used to use WeakValueDictionary but hit
            # https://github.com/ray-project/ray/issues/18980.
            victim = random.choice(list(self.handle_cache))
            self._evicted_handle_keys.add(victim)
            del self.handle_cache[victim]

        return new_handle

    @_ensure_connected
    def record_request_routing_info(self, info: RequestRoutingInfo):
        """Send a replica's routing information to the controller.

        The remote call is issued without waiting for its result.

        Args:
            info: RequestRoutingInfo including deployment name, replica tag,
                multiplex model ids, and routing stats.
        """
        record_method = self._controller.record_request_routing_info
        record_method.remote(info)

    @_ensure_connected
    def update_global_logging_config(self, logging_config: LoggingConfig):
        """Reconfigure the logging config for the controller & proxies.

        The remote call is issued without waiting for its result.
        """
        reconfigure_method = self._controller.reconfigure_global_logging_config
        reconfigure_method.remote(logging_config)
