import abc
import functools
import inspect
import logging
import os
import socket
from typing import (
    Any,
    Callable,
    ContextManager,
    Dict,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

import ray
from ray._common.network_utils import find_free_port, is_ipv6
from ray.actor import ActorHandle
from ray.air._internal.util import (
    StartTraceback,
    StartTracebackWithWorkerRank,
)
from ray.exceptions import RayActorError
from ray.types import ObjectRef

# Generic return type of the user's training function; used in the signature of
# construct_train_func to tie the wrapped callable's return type to train_func's.
T = TypeVar("T")

logger = logging.getLogger(__name__)


def check_for_failure(
    remote_values: List[ObjectRef],
) -> Tuple[bool, Optional[Exception]]:
    """Wait on every object reference and detect whether an actor died.

    References are consumed one at a time as ``ray.wait`` reports them ready.
    Each ready reference is resolved with ``ray.get`` to surface any error.

    Args:
        remote_values: Object references returned by Ray actor method calls.
            The position of a reference in this list is reported as the
            worker rank in logs and errors.

    Returns:
        ``(True, None)`` if every reference resolved successfully, or
        ``(False, exc)`` where ``exc`` is the ``RayActorError`` of the first
        failed actor that was observed.

    Raises:
        StartTracebackWithWorkerRank: If resolving a reference raised any
            exception other than ``RayActorError`` (e.g. an error in user
            training code). The original exception is chained as the cause.
    """
    pending = remote_values.copy()

    while pending:
        ready, pending = ray.wait(pending)

        # A reference whose actor died is also reported as "ready"; resolving it
        # with ray.get is what surfaces the failure as a RayActorError.
        for ref in ready:
            try:
                ray.get(ref)
            except RayActorError as actor_err:
                rank = remote_values.index(ref)
                logger.info(f"Worker {rank} has failed.")
                return False, actor_err
            except Exception as user_err:
                # Non-actor failures (e.g. training errors) propagate immediately,
                # tagged with the rank of the worker that produced them.
                rank = remote_values.index(ref)
                raise StartTracebackWithWorkerRank(worker_rank=rank) from user_err

    return True, None


def get_address_and_port() -> Tuple[str, int]:
    """Return this node's IP address together with a currently free port.

    The socket address family used to probe for a free port matches the
    address: IPv6 addresses use ``AF_INET6``, everything else ``AF_INET``.
    """
    node_ip = ray.util.get_node_ip_address()
    family = socket.AF_INET6 if is_ipv6(node_ip) else socket.AF_INET
    free_port = find_free_port(family)
    return node_ip, free_port


def update_env_vars(env_vars: Dict[str, Any]):
    """Set environment variables in the current worker process.

    Every value is converted with ``str()`` before anything is written, so a
    value that fails to stringify leaves ``os.environ`` untouched.

    Args:
        env_vars: Mapping of variable name to value to set.
    """
    as_strings: Dict[str, str] = {}
    for name, value in env_vars.items():
        as_strings[name] = str(value)
    os.environ.update(as_strings)


def count_required_parameters(fn: Callable) -> int:
    """Counts the number of required positional parameters of a function.

    A parameter is counted when it has no default value and can be supplied
    positionally (positional-only, positional-or-keyword, or ``*args``).
    Keyword-only parameters and ``**kwargs`` are never counted.

    NOTE: *args counts as 1 required parameter.

    Examples:

    >>> def fn(a, b, /, c, *args, d=1, e=2, **kwargs):
    ...     pass
    >>> count_required_parameters(fn)
    4

    >>> fn = lambda: 1
    >>> count_required_parameters(fn)
    0

    >>> def fn(config, a, b=1, c=2):
    ...     pass
    >>> from functools import partial
    >>> count_required_parameters(partial(fn, a=0))
    1
    """
    # Parameter kinds that a caller can fill positionally.
    positional_param_kinds = {
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.VAR_POSITIONAL,
    }
    # Compare against the ``Parameter.empty`` sentinel by identity. ``==`` would
    # dispatch to the default value's own ``__eq__``: e.g. a NumPy array default
    # returns an element-wise array whose truth value is ambiguous and raises.
    return sum(
        1
        for p in inspect.signature(fn).parameters.values()
        if p.default is inspect.Parameter.empty and p.kind in positional_param_kinds
    )


def construct_train_func(
    train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
    config: Optional[Dict[str, Any]],
    train_func_context: ContextManager,
    fn_arg_name: Optional[str] = "train_func",
    discard_returns: bool = False,
) -> Callable[[], T]:
    """Validate a user training function and wrap it as a zero-argument callable.

    Args:
        train_func: The training function to execute. It must take either no
            required arguments or exactly one (the ``config`` dict).
        config: Configuration passed to ``train_func`` when it takes one
            required argument. ``None`` is replaced with an empty dict.
            Ignored when ``train_func`` takes no required arguments.
        train_func_context: Zero-argument callable returning a context manager
            that is entered around every call of ``train_func``, used for
            backend-specific setup/teardown.
        fn_arg_name: Name used for ``train_func`` in error messages.
        discard_returns: If True, whatever ``train_func`` returns is dropped
            and the wrapper returns ``None``.

    Returns:
        A callable taking no arguments that runs ``train_func`` inside
        ``train_func_context()``. Any exception raised inside is re-raised
        as ``StartTraceback`` chained from the original.

    Raises:
        ValueError: If ``train_func`` has more than one required argument.
    """
    num_required_params = count_required_parameters(train_func)
    if num_required_params > 1:
        raise ValueError(
            f"{fn_arg_name} should take in 0 or 1 required arguments, but it accepts "
            f"{num_required_params} required arguments instead."
        )

    if not discard_returns:
        wrapped_train_func = train_func
    else:
        # Drop the return value so BackendExecutor never tries to deserialize
        # it; such values are inaccessible through AIR anyway.
        @functools.wraps(train_func)
        def discard_return_wrapper(*args, **kwargs):
            try:
                train_func(*args, **kwargs)
            except Exception as e:
                raise StartTraceback from e

        wrapped_train_func = discard_return_wrapper

    # Positional arguments forwarded to the user function: the config dict for a
    # one-argument function, nothing otherwise.
    call_args = ({} if config is None else config,) if num_required_params == 1 else ()

    @functools.wraps(wrapped_train_func)
    def train_fn():
        try:
            with train_func_context():
                return wrapped_train_func(*call_args)
        except Exception as e:
            raise StartTraceback from e

    return train_fn


class Singleton(abc.ABCMeta):
    """Metaclass that gives every class using it at most one instance.

    It derives from ``abc.ABCMeta`` so that it can be combined with abstract
    base classes. The first call to a class constructs and caches the instance;
    later calls return the cached object and ignore their arguments.

    https://stackoverflow.com/questions/33364070/implementing
    -singleton-as-metaclass-but-for-abstract-classes
    """

    # Maps each class created with this metaclass to its single instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        cache = cls._instances
        if cls in cache:
            return cache[cls]
        instance = super().__call__(*args, **kwargs)
        cache[cls] = instance
        return instance


class ActorWrapper:
    """Wraps an actor to provide same API as using the base class directly.

    Attribute access is forwarded to the wrapped actor handle. The result is a
    plain function that invokes the actor method remotely and blocks on
    ``ray.get`` for its result, so ``wrapper.method(x)`` behaves like a local
    call.
    """

    def __init__(self, actor: ActorHandle):
        self.actor = actor

    def __getattr__(self, item):
        # __getattr__ only runs for attributes not found normally. If ``actor``
        # itself is missing (e.g. an instance created by copy/pickle that probes
        # ``__setstate__`` before restoring ``__dict__``), reading ``self.actor``
        # below would re-enter __getattr__ forever; fail with AttributeError.
        if item == "actor":
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute 'actor'"
            )
        # The below will fail if trying to access an attribute (not a method) from
        # the actor, since the result is expected to have a ``.remote`` method.
        actor_method = getattr(self.actor, item)
        return lambda *args, **kwargs: ray.get(actor_method.remote(*args, **kwargs))
