import abc
import asyncio
import datetime
import functools
import importlib
import json
import logging
import os
import pkgutil
from abc import ABCMeta, abstractmethod
from base64 import b64decode
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from enum import IntEnum
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from ray._common.utils import binary_to_hex

if TYPE_CHECKING:
    from ray.core.generated.node_manager_pb2 import GetNodeStatsReply

from packaging.version import Version

import ray
import ray._private.protobuf_compat
import ray._private.ray_constants as ray_constants
import ray._private.services as services
import ray.experimental.internal_kv as internal_kv
from ray._common.network_utils import parse_address
from ray._common.utils import get_or_create_event_loop
from ray._private.gcs_utils import GcsChannel
from ray._private.utils import (
    get_dashboard_dependency_error,
    split_address,
)
from ray._raylet import GcsClient

# Module-level alias for scheduling coroutines as tasks. ``asyncio.create_task``
# exists on every Python version this module can run on (it already uses 3.8+
# syntax such as the walrus operator), so the old ``asyncio.ensure_future``
# fallback was dead code. The alias itself is kept for backward compatibility
# in case other modules import ``create_task`` from here.
create_task = asyncio.create_task

logger = logging.getLogger(__name__)


class HTTPStatusCode(IntEnum):
    """Subset of HTTP status codes used by the dashboard.

    Being an ``IntEnum``, each member compares equal to its integer code.
    """

    # --- 2xx: success ---
    OK = 200

    # --- 4xx: client errors ---
    BAD_REQUEST = 400
    NOT_FOUND = 404
    TOO_MANY_REQUESTS = 429

    # --- 5xx: server errors ---
    INTERNAL_ERROR = 500


class FrontendNotFoundError(OSError):
    """Error type for a dashboard frontend that cannot be found.

    It subclasses ``OSError``, so existing ``OSError`` handlers also catch it.
    """


class DashboardAgentModule(abc.ABC):
    """Base class for modules loaded by the dashboard agent.

    Concrete subclasses must implement ``run`` and the static
    ``is_minimal_module``; otherwise they cannot be instantiated.
    """

    def __init__(self, dashboard_agent):
        """
        Initialize current module when DashboardAgent is loading modules.
        :param dashboard_agent: The DashboardAgent instance. It must expose
            ``session_name`` (read once here) and ``gcs_address`` (read on
            every access of the ``gcs_address`` property).
        """
        self._dashboard_agent = dashboard_agent
        self.session_name = dashboard_agent.session_name

    @abc.abstractmethod
    async def run(self, server):
        """
        Run the module in an asyncio loop. An agent module can provide
        servicers to the server.
        :param server: Asyncio GRPC server, or None if ray is minimal.
        """

    # ``abc.abstractclassmethod`` is deprecated; stacking ``staticmethod`` over
    # ``abc.abstractmethod`` is the supported way to declare an abstract
    # static method, and keeps the method abstract for instantiation checks.
    @staticmethod
    @abc.abstractmethod
    def is_minimal_module():
        """
        Return True if the module is minimal, meaning it
        should work with `pip install ray` that doesn't require additional
        dependencies.
        """

    @property
    def gcs_address(self):
        """GCS address, delegated to the owning dashboard agent."""
        return self._dashboard_agent.gcs_address


@dataclass
class DashboardHeadModuleConfig:
    """Construction-time settings shared by every ``DashboardHeadModule``.

    Field order is part of the positional constructor signature; do not
    reorder.
    """

    # True when the dashboard runs without the optional (non-minimal) deps.
    minimal: bool
    # Cluster id as a hex string, passed to the GCS client.
    cluster_id_hex: str
    session_name: str
    # Address of the GCS server.
    gcs_address: str
    log_dir: str
    temp_dir: str
    session_dir: str
    ip: str
    http_host: str
    http_port: int


class DashboardHeadModule(abc.ABC):
    """Base class for modules loaded by the dashboard head.

    Exposes the fields of the module config as read-only properties and
    lazily creates shared clients (HTTP session, GCS client, async GCS
    channel) on first access. Concrete subclasses must implement ``run`` and
    the static ``is_minimal_module``.
    """

    def __init__(self, config: "DashboardHeadModuleConfig"):
        """
        Initialize current module when DashboardHead is loading modules.
        :param config: The DashboardHeadModuleConfig instance.
        """
        self._config = config
        self._gcs_client = None  # lazy init
        self._aiogrpc_gcs_channel = None  # lazy init
        self._http_session = None  # lazy init

    @property
    def minimal(self):
        return self._config.minimal

    @property
    def session_name(self):
        return self._config.session_name

    @property
    def gcs_address(self):
        return self._config.gcs_address

    @property
    def log_dir(self):
        return self._config.log_dir

    @property
    def temp_dir(self):
        return self._config.temp_dir

    @property
    def session_dir(self):
        return self._config.session_dir

    @property
    def ip(self):
        return self._config.ip

    @property
    def http_host(self):
        return self._config.http_host

    @property
    def http_port(self):
        return self._config.http_port

    @property
    def http_session(self):
        """Lazily created ``aiohttp.ClientSession`` reused by this module.

        Must not be used in minimal mode (asserts), because aiohttp is not
        imported until this property is accessed.
        """
        assert not self._config.minimal, "http_session accessed in minimal Ray."
        import aiohttp

        if self._http_session is not None:
            return self._http_session
        # Create a http session for all modules.
        # aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore
        if Version(aiohttp.__version__) < Version("4.0.0"):
            self._http_session = aiohttp.ClientSession(loop=get_or_create_event_loop())
        else:
            self._http_session = aiohttp.ClientSession()
        return self._http_session

    @property
    def gcs_client(self):
        """Lazily created ``GcsClient`` for this module's GCS address.

        On first creation, it also initializes the process-wide internal KV
        with this client if the internal KV has not been initialized yet.
        """
        if self._gcs_client is None:
            self._gcs_client = GcsClient(
                address=self._config.gcs_address,
                cluster_id=self._config.cluster_id_hex,
            )
            if not internal_kv._internal_kv_initialized():
                internal_kv._initialize_internal_kv(self._gcs_client)
        return self._gcs_client

    @property
    def aiogrpc_gcs_channel(self):
        """Lazily created async gRPC channel to GCS, or None in minimal mode."""
        # TODO(ryw): once we removed the old gcs client, also remove this.
        if self._config.minimal:
            return None
        if self._aiogrpc_gcs_channel is None:
            gcs_channel = GcsChannel(gcs_address=self._config.gcs_address, aio=True)
            gcs_channel.connect()
            self._aiogrpc_gcs_channel = gcs_channel.channel()
        return self._aiogrpc_gcs_channel

    @abc.abstractmethod
    async def run(self):
        """
        Run the module in an asyncio loop. A head module can provide
        servicers to the server.
        """

    # ``abc.abstractclassmethod`` is deprecated; stacking ``staticmethod`` over
    # ``abc.abstractmethod`` is the supported way to declare an abstract
    # static method.
    @staticmethod
    @abc.abstractmethod
    def is_minimal_module():
        """
        Return True if the module is minimal, meaning it
        should work with `pip install ray` that doesn't require additional
        dependencies.
        """


class RateLimitedModule(abc.ABC):
    """Simple concurrency limiter shared by the decorated methods of an instance.

    Subclass this and decorate coroutine methods with
    ``RateLimitedModule.enforce_max_concurrent_calls``. The number of
    in-flight invocations across **all** decorated methods of one instance is
    capped at ``max_num_call``. Any call arriving while the cap is reached is
    answered by awaiting ``limit_handler_`` instead of the decorated method.

    The Example class below allows at most 10 concurrent calls to A() and B()
    combined:

        class Example(RateLimitedModule):
            def __init__(self):
                super().__init__(max_num_call=10)

            @RateLimitedModule.enforce_max_concurrent_calls
            async def A(self):
                ...

            @RateLimitedModule.enforce_max_concurrent_calls
            async def B(self):
                ...

            async def limit_handler_(self):
                raise RuntimeError("rate limited reached!")

    """

    def __init__(self, max_num_call: int, logger: Optional[logging.Logger] = None):
        """
        Args:
            max_num_call: Maximal number of concurrent invocations of all decorated
                functions in the instance. Any negative value (e.g. -1)
                disables rate limiting.
            logger: Optional logger; when set, a warning is logged each time a
                call is rejected.
        """
        self.max_num_call_ = max_num_call
        self.num_call_ = 0
        self.logger_ = logger

    @staticmethod
    def enforce_max_concurrent_calls(func):
        """Decorator that enforces the instance's max number of concurrent calls.

        NOTE: This should be the innermost decorator when several are stacked,
        e.g. placed below any route decorators such as ``@routes.get(...)``:
            ```
            @routes.get('/')
            @RateLimitedModule.enforce_max_concurrent_calls
            async def fn(self):
                ...

            ```
        """

        @functools.wraps(func)
        async def async_wrapper(self, *args, **kwargs):
            limit = self.max_num_call_
            at_capacity = limit >= 0 and self.num_call_ >= limit
            if at_capacity:
                if self.logger_:
                    self.logger_.warning(
                        f"Max concurrent requests reached={self.max_num_call_}"
                    )
                return await self.limit_handler_()

            # Count this call as in-flight for its whole duration, including
            # when it raises.
            self.num_call_ += 1
            try:
                return await func(self, *args, **kwargs)
            finally:
                self.num_call_ -= 1

        # The wrapper receives ``self`` at call time, so the decorator itself
        # needs no instance.
        return async_wrapper

    @abstractmethod
    async def limit_handler_(self):
        """Handler that is invoked when max number of concurrent calls reached"""


def dashboard_module(enable):
    """Class decorator that marks a dashboard module as enabled or disabled.

    The flag is stored on the decorated class as
    ``__ray_dashboard_module_enable__``. ``get_all_modules`` skips classes
    whose flag is falsy; undecorated classes count as enabled.
    """

    def _mark(target_cls):
        setattr(target_cls, "__ray_dashboard_module_enable__", enable)
        return target_cls

    return _mark


def get_all_modules(module_type):
    """
    Get all importable modules that are subclasses of a given module type.

    Every submodule under ``ray.dashboard.modules`` is imported so that the
    classes it defines become visible through ``module_type.__subclasses__()``.

    Args:
        module_type: Base class (e.g. a dashboard module base) whose direct
            subclasses are collected.

    Returns:
        List of direct subclasses of ``module_type`` that are enabled (see
        ``dashboard_module``) and, when dashboard dependencies are missing,
        that report ``is_minimal_module()``.

    Raises:
        ModuleNotFoundError: If a submodule fails to import even though the
            full dashboard dependencies are available.
    """
    logger.info(f"Get all modules by type: {module_type.__name__}")
    import ray.dashboard.modules

    # If the full dashboard dependencies are missing, import failures are
    # expected and only minimal modules are returned.
    should_only_load_minimal_modules = get_dashboard_dependency_error() is not None

    for _, name, _ in pkgutil.walk_packages(
        ray.dashboard.modules.__path__, ray.dashboard.modules.__name__ + "."
    ):
        try:
            importlib.import_module(name)
        except ModuleNotFoundError as e:
            logger.info(
                f"Module {name} cannot be loaded because "
                "we cannot import all dependencies. Install this module using "
                "`pip install 'ray[default]'` for the full "
                f"dashboard functionality. Error: {e}"
            )
            if not should_only_load_minimal_modules:
                logger.info(
                    "Although the dependencies from `pip install 'ray[default]'` "
                    f"are installed, module {name} couldn't be imported."
                )
                raise

    imported_modules = []
    # module_type.__subclasses__() should contain modules that
    # we could successfully import.
    for m in module_type.__subclasses__():
        if not getattr(m, "__ray_dashboard_module_enable__", True):
            continue
        if should_only_load_minimal_modules and not m.is_minimal_module():
            continue
        imported_modules.append(m)
    logger.info(f"Available modules: {imported_modules}")
    return imported_modules


def to_posix_time(dt):
    """Return the POSIX timestamp (seconds since 1970-01-01 UTC) of ``dt``.

    Naive datetimes are interpreted as UTC (previous behavior, unchanged).
    Timezone-aware datetimes are measured against the UTC epoch; before,
    they raised ``TypeError`` because a naive epoch cannot be subtracted
    from an aware datetime.

    Args:
        dt: A ``datetime.datetime``, naive (assumed UTC) or timezone-aware.

    Returns:
        Float number of seconds since the Unix epoch.
    """
    if dt.utcoffset() is not None:
        epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
    else:
        epoch = datetime.datetime(1970, 1, 1)
    return (dt - epoch).total_seconds()


def address_tuple(address):
    """Normalize ``address`` into an ``(ip, port)`` tuple with an int port.

    Tuples are returned as-is (not validated). Anything else is split with
    ``parse_address`` and its port converted with ``int``.
    """
    if not isinstance(address, tuple):
        host, port_text = parse_address(address)
        address = (host, int(port_text))
    return address


def node_stats_to_dict(
    message: "GetNodeStatsReply",
) -> Optional[Dict[str, List[Dict[str, Any]]]]:
    """Convert a ``GetNodeStatsReply`` into a dict with hex-decoded id fields.

    The top-level message and each entry of ``core_workers_stats`` are
    converted with ``message_to_dict``, decoding the listed id fields. The
    per-worker entries also print fields with no presence. The worker list is
    stored under ``"coreWorkersStats"``, replacing whatever the top-level
    conversion put there.
    """
    id_field_names = {
        "actorId",
        "jobId",
        "taskId",
        "parentTaskId",
        "sourceActorId",
        "callerId",
        "nodeId",
        "workerId",
        "placementGroupId",
    }
    converted = message_to_dict(message, id_field_names)
    converted["coreWorkersStats"] = [
        message_to_dict(
            worker_stats, id_field_names, always_print_fields_with_no_presence=True
        )
        for worker_stats in message.core_workers_stats
    ]
    return converted


class CustomEncoder(json.JSONEncoder):
    """JSON encoder that also serializes ``bytes`` and ``Immutable`` wrappers.

    ``bytes`` are converted with ``binary_to_hex``. ``Immutable`` containers
    are replaced by their underlying mutable value. Any other unsupported
    object is handed to the base class, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, bytes):
            encoded = binary_to_hex(obj)
        elif isinstance(obj, Immutable):
            encoded = obj.mutable()
        else:
            encoded = json.JSONEncoder.default(self, obj)
        return encoded


def to_camel_case(snake_str):
    """Convert a snake_case string to camelCase.

    The first ``_``-separated component is kept verbatim. Each later
    component is passed through ``str.title`` and appended, so empty
    components (from repeated underscores) vanish.
    """
    head, *tail = snake_str.split("_")
    return head + "".join(map(str.title, tail))


def to_google_style(d):
    """Recursively convert every key of ``d`` to camelCase (Google JSON style).

    Nested dicts are converted recursively, as are dicts that appear directly
    inside list values. Other values, including nested lists, are copied by
    reference. A new dict is returned; ``d`` is not modified.
    """

    def _convert_value(value):
        if isinstance(value, dict):
            return to_google_style(value)
        if isinstance(value, list):
            return [
                to_google_style(item) if isinstance(item, dict) else item
                for item in value
            ]
        return value

    return {to_camel_case(key): _convert_value(value) for key, value in d.items()}


def message_to_dict(message, decode_keys=None, **kwargs):
    """Convert protobuf message to Python dict.

    Args:
        message: The protobuf message to convert.
        decode_keys: Optional collection of (converted, camelCase) field names
            whose scalar values are base64 strings. Those values are
            base64-decoded and re-encoded with ``binary_to_hex``. Nested dicts,
            and dicts directly inside lists, are processed recursively.
        **kwargs: Forwarded to ``ray._private.protobuf_compat.message_to_dict``.

    Returns:
        The converted dict (enums rendered by name, not integer).
    """

    def _decode_keys(d):
        # Values are replaced in place; only existing keys are reassigned, so
        # mutating ``d`` during iteration is safe.
        for k, v in d.items():
            if isinstance(v, dict):
                d[k] = _decode_keys(v)
            elif isinstance(v, list):
                d[k] = [_decode_keys(i) if isinstance(i, dict) else i for i in v]
            elif k in decode_keys:
                # Previously a nested dict under a decode key fell through to
                # here (``if``/``if`` instead of ``elif``) and crashed in
                # b64decode; only scalar values are decoded now.
                d[k] = binary_to_hex(b64decode(v))
        return d

    d = ray._private.protobuf_compat.message_to_dict(
        message, use_integers_for_enums=False, **kwargs
    )
    if decode_keys:
        return _decode_keys(d)
    else:
        return d


class Bunch(dict):
    """A dict with attribute-access.

    ``b.key`` reads, writes and deletes ``b["key"]``. Missing keys raise
    ``AttributeError`` rather than ``KeyError`` so that ``hasattr`` and
    ``getattr(b, name, default)`` behave as expected.
    """

    def __getattr__(self, key):
        try:
            return self.__getitem__(key)
        except KeyError:
            # Suppress the KeyError context: the attribute protocol failed.
            raise AttributeError(key) from None

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __delattr__(self, key):
        # Without this, ``del b.key`` went to object.__delattr__ and always
        # failed, because attribute writes are stored as dict items.
        try:
            self.__delitem__(key)
        except KeyError:
            raise AttributeError(key) from None


"""
https://docs.python.org/3/library/json.html?highlight=json#json.JSONEncoder
    +-------------------+---------------+
    | Python            | JSON          |
    +===================+===============+
    | dict              | object        |
    +-------------------+---------------+
    | list, tuple       | array         |
    +-------------------+---------------+
    | str               | string        |
    +-------------------+---------------+
    | int, float        | number        |
    +-------------------+---------------+
    | True              | true          |
    +-------------------+---------------+
    | False             | false         |
    +-------------------+---------------+
    | None              | null          |
    +-------------------+---------------+
"""
# Types that may be stored inside Immutable containers without conversion:
# the directly JSON-encodable Python types from the table above, plus
# ``bytes``. The Immutable wrapper classes are added after they are defined.
_json_compatible_types = {
    type(None),
    bool,
    int,
    float,
    str,
    bytes,
    list,
    tuple,
    dict,
}


def is_immutable(self):
    """Always raise ``TypeError`` naming the class of ``self`` as immutable."""
    raise TypeError(f"{self.__class__.__name__!r} objects are immutable")


def make_immutable(value, strict=True):
    """Return an immutable view of ``value``.

    Exact ``dict``/``list`` instances (not subclasses) are wrapped in
    ``ImmutableDict``/``ImmutableList`` without copying. Other values are
    returned unchanged. With ``strict=True``, a value whose exact type is not
    in ``_json_compatible_types`` raises ``TypeError``.
    """
    kind = type(value)
    if kind is dict:
        return ImmutableDict(value)
    if kind is list:
        return ImmutableList(value)
    if strict and kind not in _json_compatible_types:
        raise TypeError("Type {} can't be immutable.".format(kind))
    return value


class Immutable(metaclass=ABCMeta):
    """Abstract base for read-only wrappers around mutable containers."""

    @abstractmethod
    def mutable(self):
        """Return the wrapped mutable object."""


class ImmutableList(Immutable, Sequence):
    """Makes a :class:`list` immutable.

    Wraps a plain ``list`` without copying it. Elements are converted with
    ``make_immutable`` on first indexed access and cached in ``_proxy``, so
    nested lists/dicts are exposed as immutable wrappers too. Slicing returns
    a new ``ImmutableList`` over the sliced elements.
    """

    __slots__ = ("_list", "_proxy")

    def __init__(self, list_value):
        if type(list_value) not in (list, ImmutableList):
            raise TypeError(f"{type(list_value)} object is not a list.")
        if isinstance(list_value, ImmutableList):
            # Re-wrapping shares the already-wrapped underlying list.
            list_value = list_value.mutable()
        self._list = list_value
        # Per-index cache of immutable element views; None means "not built".
        self._proxy = [None] * len(list_value)

    def __reduce_ex__(self, protocol):
        return type(self), (self._list,)

    def mutable(self):
        return self._list

    def __eq__(self, other):
        if isinstance(other, ImmutableList):
            other = other.mutable()
        return list.__eq__(self._list, other)

    def __ne__(self, other):
        if isinstance(other, ImmutableList):
            other = other.mutable()
        return list.__ne__(self._list, other)

    def __contains__(self, item):
        if isinstance(item, Immutable):
            item = item.mutable()
        return list.__contains__(self._list, item)

    def __getitem__(self, item):
        if isinstance(item, slice):
            # Previously this returned the raw slice of ``_proxy``, i.e. a
            # plain list of cached views or ``None`` placeholders, instead of
            # the elements.
            return ImmutableList(self._list[item])
        proxy = self._proxy[item]
        if proxy is None:
            proxy = self._proxy[item] = make_immutable(self._list[item])
        return proxy

    def __len__(self):
        return len(self._list)

    def __repr__(self):
        return "%s(%s)" % (self.__class__.__name__, list.__repr__(self._list))


class ImmutableDict(Immutable, Mapping):
    """Makes a :class:`dict` immutable.

    Wraps a plain ``dict`` without copying it. Values are converted with
    ``make_immutable`` on first access and cached in ``_proxy``, so nested
    dicts/lists are exposed as immutable wrappers too. Iteration follows the
    key order of the wrapped dict.
    """

    __slots__ = ("_dict", "_proxy")

    def __init__(self, dict_value):
        if type(dict_value) not in (dict, ImmutableDict):
            raise TypeError(f"{type(dict_value)} object is not a dict.")
        if isinstance(dict_value, ImmutableDict):
            # Re-wrapping shares the already-wrapped underlying dict.
            dict_value = dict_value.mutable()
        self._dict = dict_value
        # Lazily filled cache: key -> immutable view of the value.
        self._proxy = {}

    def __reduce_ex__(self, protocol):
        return type(self), (self._dict,)

    def mutable(self):
        return self._dict

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return make_immutable(default)

    def __eq__(self, other):
        if isinstance(other, ImmutableDict):
            other = other.mutable()
        return dict.__eq__(self._dict, other)

    def __ne__(self, other):
        if isinstance(other, ImmutableDict):
            other = other.mutable()
        return dict.__ne__(self._dict, other)

    def __contains__(self, item):
        if isinstance(item, Immutable):
            item = item.mutable()
        return dict.__contains__(self._dict, item)

    def __getitem__(self, item):
        proxy = self._proxy.get(item, None)
        if proxy is None:
            proxy = self._proxy[item] = make_immutable(self._dict[item])
        return proxy

    def __len__(self) -> int:
        return len(self._dict)

    def __iter__(self):
        # Iterate the wrapped dict itself. The old implementation iterated the
        # proxy cache, whose order depends on access history (and on set
        # ordering for never-accessed keys), and could yield stale keys.
        # Values are still wrapped lazily via ``__getitem__`` by
        # ``items()``/``values()``.
        return iter(self._dict)

    def __repr__(self):
        return "%s(%s)" % (self.__class__.__name__, dict.__repr__(self._dict))


# Allow Immutable wrappers (ImmutableList, ImmutableDict) to be nested inside
# other Immutable containers under ``make_immutable(strict=True)``.
_json_compatible_types.update(Immutable.__subclasses__())


def async_loop_forever(interval_seconds, cancellable=False):
    """Decorator factory that re-invokes a coroutine function forever.

    After every invocation, whether it succeeded or failed, the loop sleeps
    for ``interval_seconds``. Exceptions raised by the coroutine are logged
    and the loop continues.

    Args:
        interval_seconds: Delay between invocations, passed to asyncio.sleep.
        cancellable: If True, an ``asyncio.CancelledError`` raised while the
            wrapped coroutine is running is logged and re-raised, ending the
            loop. If False, it is logged and the loop keeps going. Note that a
            cancellation delivered during the sleep between invocations is not
            intercepted and always propagates.
    """

    def _decorate(coro):
        @functools.wraps(coro)
        async def _looper(*args, **kwargs):
            while True:
                try:
                    await coro(*args, **kwargs)
                except asyncio.CancelledError as ex:
                    if not cancellable:
                        logger.exception(
                            f"Can not cancel the async loop forever coroutine {coro}."
                        )
                    else:
                        logger.info(
                            f"An async loop forever coroutine is cancelled {coro}."
                        )
                        raise ex
                except Exception:
                    logger.exception(f"Error looping coroutine {coro}.")
                await asyncio.sleep(interval_seconds)

        return _looper

    return _decorate


def ray_client_address_to_api_server_url(address: str):
    """Convert a Ray Client address of a running Ray cluster to its API server URL.

    Connects with ``ray.init`` (used as a context manager, so the connection
    is closed on exit) just long enough to read the dashboard URL.

    Args:
        address: The Ray Client address, e.g. "ray://my-cluster".

    Returns:
        str: The API server URL of the cluster, e.g. "http://<head-node-ip>:8265".
    """
    with ray.init(address=address) as ctx:
        host_and_port = ctx.dashboard_url
    return f"http://{host_and_port}"


def ray_address_to_api_server_url(address: Optional[str]) -> str:
    """Parse a Ray cluster address into API server URL.

    When an address is provided, it will be used to query GCS for
    API server address from GCS, so a Ray cluster must be running.

    When an address is not provided, it will first try to auto-detect
    a running Ray instance, or look for local GCS process.

    Side effect: initializes the process-wide internal KV with a new
    ``GcsClient`` for the resolved address.

    Args:
        address: Ray cluster bootstrap address or Ray Client address.
            Could also be `auto`.

    Returns:
        API server HTTP URL.

    Raises:
        ValueError: If the dashboard address could not be read from GCS.
    """
    bootstrap_address = services.canonicalize_bootstrap_address_or_die(address)
    client = GcsClient(address=bootstrap_address)
    ray.experimental.internal_kv._initialize_internal_kv(client)

    raw_url = ray._private.utils.internal_kv_get_with_retry(
        client,
        ray_constants.DASHBOARD_ADDRESS,
        namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
        num_retries=20,
    )
    if raw_url is None:
        raise ValueError(
            "Couldn't obtain the API server address from GCS. It is likely that "
            "the GCS server is down. Check gcs_server.[out | err] to see if it is "
            "still alive."
        )
    # The KV store returns bytes.
    return f"http://{raw_url.decode()}"


def get_address_for_submission_client(address: Optional[str]) -> str:
    """Get Ray API server address from Ray bootstrap or Client address.

    If None, it will try to auto-detect a running Ray instance, or look
    for local GCS process.

    `address` is overridden by environment variables, just like the `address`
    argument in `ray.init()`: RAY_API_SERVER_ADDRESS takes precedence, and
    RAY_ADDRESS is used only when RAY_API_SERVER_ADDRESS is unset or empty.

    Resolution:
        * "ray://..." addresses are resolved through a Ray Client connection.
        * Other addresses containing "://" (e.g. "http://...") are returned
          unchanged.
        * Anything else (including None/empty) is resolved via GCS.

    Args:
        address: Ray cluster bootstrap address or Ray Client address.
            Could also be "auto".

    Returns:
        API server HTTP URL, e.g. "http://<head-node-ip>:8265".
    """
    env_api_server = os.environ.get(
        ray_constants.RAY_API_SERVER_ADDRESS_ENVIRONMENT_VARIABLE
    )
    if env_api_server:
        address = env_api_server
        logger.debug(f"Using RAY_API_SERVER_ADDRESS={address}")
    else:
        # Fall back to RAY_ADDRESS if RAY_API_SERVER_ADDRESS not set
        env_ray_address = os.environ.get(ray_constants.RAY_ADDRESS_ENVIRONMENT_VARIABLE)
        if env_ray_address:
            address = env_ray_address
            logger.debug(f"Using RAY_ADDRESS={address}")

    if not (address and "://" in address):
        # User specified a non-Ray-Client Ray cluster address.
        address = ray_address_to_api_server_url(address)
    else:
        scheme, _ = split_address(address)
        if scheme == "ray":
            logger.debug(
                f"Retrieving API server address from Ray Client address {address}..."
            )
            address = ray_client_address_to_api_server_url(address)
    logger.debug(f"Using API server address {address}.")
    return address


def compose_state_message(
    death_reason: Optional[str], death_reason_message: Optional[str]
) -> Optional[str]:
    """Compose node state message based on death information.

    A known ``death_reason`` maps to a human-readable prefix. A non-empty
    ``death_reason_message`` is appended after ": ", or used alone when there
    is no prefix. Returns None when neither yields any text.

    Args:
        death_reason: The reason of node death.
            This is a string representation of `gcs_pb2.NodeDeathInfo.Reason`.
        death_reason_message: The message of node death.
            This corresponds to `gcs_pb2.NodeDeathInfo.ReasonMessage`.
    """
    reason_descriptions = {
        "EXPECTED_TERMINATION": "Expected termination",
        "UNEXPECTED_TERMINATION": "Unexpected termination",
        "AUTOSCALER_DRAIN_PREEMPTED": "Terminated due to preemption",
        "AUTOSCALER_DRAIN_IDLE": "Terminated due to idle (no Ray activity)",
    }
    prefix = reason_descriptions.get(death_reason)

    if not death_reason_message:
        return prefix
    if not prefix:
        return death_reason_message
    return f"{prefix}: {death_reason_message}"


def close_logger_file_descriptor(logger_instance):
    """Call ``close()`` on every handler attached to ``logger_instance``.

    Handlers are closed but left attached to the logger.
    """
    for attached_handler in logger_instance.handlers:
        attached_handler.close()
