"""Common pre-checks for all RLlib experiments."""
import logging
from typing import TYPE_CHECKING, Set

import gymnasium as gym
import numpy as np
import tree  # pip install dm_tree

from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.error import ERR_MSG_OLD_GYM_API, UnsupportedSpaceException
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.util import log_once

if TYPE_CHECKING:
    from ray.rllib.env import MultiAgentEnv

logger = logging.getLogger(__name__)


@DeveloperAPI
def check_multiagent_environments(env: "MultiAgentEnv") -> None:
    """Checking for common errors in RLlib MultiAgentEnvs.

    Resets the env once (with ``seed=42``), samples one action for every agent
    that returned an observation, steps the env once, and validates that all
    returned elements are MultiAgentDicts keyed by the env's agent IDs and that
    rewards, dones, truncateds and infos have the expected types.

    If the env lacks some of the base-class attributes (``observation_space``,
    ``action_space``, ``_agent_ids``), a one-time warning is logged and the
    remaining checks are skipped.

    Args:
        env: The env to be checked.

    Raises:
        ValueError: If `env` is not a MultiAgentEnv, if `reset()`/`step()` raise
            or do not return the expected number of items (e.g. old gym API), or
            if any returned element fails validation.
    """
    from ray.rllib.env import MultiAgentEnv

    def _api_error(method_name):
        # Built lazily so `env` is only stringified when an error is reported.
        return ValueError(
            ERR_MSG_OLD_GYM_API.format(
                env,
                f"In particular, the `{method_name}()` method seems to be faulty.",
            )
        )

    if not isinstance(env, MultiAgentEnv):
        raise ValueError("The passed env is not a MultiAgentEnv.")
    elif not (
        hasattr(env, "observation_space")
        and hasattr(env, "action_space")
        and hasattr(env, "_agent_ids")
    ):
        if log_once("ma_env_super_ctor_called"):
            logger.warning(
                f"Your MultiAgentEnv {env} does not have some or all of the needed "
                "base-class attributes! Make sure you call `super().__init__()` from "
                "within your MultiAgentEnv's constructor. "
                "This will raise an error in the future."
            )
        return

    try:
        obs_and_infos = env.reset(seed=42, options={})
    except Exception as e:
        raise _api_error("reset") from e
    # gymnasium's `reset()` returns `(obs, infos)`; the old gym API returned only
    # `obs`, which would otherwise surface as an opaque unpacking error here.
    try:
        reset_obs, _ = obs_and_infos
    except (TypeError, ValueError) as e:
        raise _api_error("reset") from e

    _check_if_element_multi_agent_dict(env, reset_obs, "reset()")

    sampled_action = {
        aid: env.get_action_space(aid).sample() for aid in reset_obs.keys()
    }
    _check_if_element_multi_agent_dict(
        env, sampled_action, "get_action_space(agent_id=..).sample()"
    )

    try:
        results = env.step(sampled_action)
    except Exception as e:
        raise _api_error("step") from e
    # gymnasium's `step()` returns 5 items; the old gym API returned 4.
    try:
        next_obs, reward, done, truncated, info = results
    except (TypeError, ValueError) as e:
        raise _api_error("step") from e

    _check_if_element_multi_agent_dict(env, next_obs, "step, next_obs")
    _check_if_element_multi_agent_dict(env, reward, "step, reward")
    _check_if_element_multi_agent_dict(env, done, "step, done")
    _check_if_element_multi_agent_dict(env, truncated, "step, truncated")
    _check_if_element_multi_agent_dict(env, info, "step, info", allow_common=True)
    # The helpers below expect a {env_id: MultiAgentDict} nesting, so the single
    # env's results are wrapped under a dummy env ID.
    _check_reward({"dummy_env_id": reward}, base_env=True, agent_ids=env.agents)
    _check_done_and_truncated(
        {"dummy_env_id": done},
        {"dummy_env_id": truncated},
        base_env=True,
        agent_ids=env.agents,
    )
    _check_info({"dummy_env_id": info}, base_env=True, agent_ids=env.agents)


def _check_reward(reward, base_env=False, agent_ids=None):
    if base_env:
        for _, multi_agent_dict in reward.items():
            for agent_id, rew in multi_agent_dict.items():
                if not (
                    np.isreal(rew)
                    and not isinstance(rew, bool)
                    and (
                        np.isscalar(rew)
                        or (isinstance(rew, np.ndarray) and rew.shape == ())
                    )
                ):
                    error = (
                        "Your step function must return rewards that are"
                        f" integer or float. reward: {rew}. Instead it was a "
                        f"{type(rew)}"
                    )
                    raise ValueError(error)
                if not (agent_id in agent_ids or agent_id == "__all__"):
                    error = (
                        f"Your reward dictionary must have agent ids that belong to "
                        f"the environment. AgentIDs received from "
                        f"env.agents are: {agent_ids}"
                    )
                    raise ValueError(error)
    elif not (
        np.isreal(reward)
        and not isinstance(reward, bool)
        and (
            np.isscalar(reward)
            or (isinstance(reward, np.ndarray) and reward.shape == ())
        )
    ):
        error = (
            "Your step function must return a reward that is integer or float. "
            "Instead it was a {}".format(type(reward))
        )
        raise ValueError(error)


def _check_done_and_truncated(done, truncated, base_env=False, agent_ids=None):
    for what in ["done", "truncated"]:
        data = done if what == "done" else truncated
        if base_env:
            for _, multi_agent_dict in data.items():
                for agent_id, done_ in multi_agent_dict.items():
                    if not isinstance(done_, (bool, np.bool_)):
                        raise ValueError(
                            f"Your step function must return `{what}s` that are "
                            f"boolean. But instead was a {type(data)}"
                        )
                    if not (agent_id in agent_ids or agent_id == "__all__"):
                        error = (
                            f"Your `{what}s` dictionary must have agent ids that "
                            f"belong to the environment. AgentIDs received from "
                            f"env.agents are: {agent_ids}"
                        )
                        raise ValueError(error)
        elif not isinstance(data, (bool, np.bool_)):
            error = (
                f"Your step function must return a `{what}` that is a boolean. But "
                f"instead was a {type(data)}"
            )
            raise ValueError(error)


def _check_info(info, base_env=False, agent_ids=None):
    if base_env:
        for _, multi_agent_dict in info.items():
            for agent_id, inf in multi_agent_dict.items():
                if not isinstance(inf, dict):
                    raise ValueError(
                        "Your step function must return infos that are a dict. "
                        f"instead was a {type(inf)}: element: {inf}"
                    )
                if not (
                    agent_id in agent_ids
                    or agent_id == "__all__"
                    or agent_id == "__common__"
                ):
                    error = (
                        f"Your dones dictionary must have agent ids that belong to "
                        f"the environment. AgentIDs received from "
                        f"env.agents are: {agent_ids}"
                    )
                    raise ValueError(error)
    elif not isinstance(info, dict):
        error = (
            "Your step function must return a info that "
            f"is a dict. element type: {type(info)}. element: {info}"
        )
        raise ValueError(error)


def _not_contained_error(func_name, _type):
    _error = (
        f"The {_type} collected from {func_name} was not contained within"
        f" your env's {_type} space. Its possible that there was a type"
        f"mismatch (for example {_type}s of np.float32 and a space of"
        f"np.float64 {_type}s), or that one of the sub-{_type}s was"
        f"out of bounds"
    )
    return _error


def _check_if_element_multi_agent_dict(
    env,
    element,
    function_string,
    base_env=False,
    allow_common=False,
):
    if not isinstance(element, dict):
        if base_env:
            error = (
                f"The element returned by {function_string} contains values "
                f"that are not MultiAgentDicts. Instead, they are of "
                f"type: {type(element)}"
            )
        else:
            error = (
                f"The element returned by {function_string} is not a "
                f"MultiAgentDict. Instead, it is of type: "
                f" {type(element)}"
            )
        raise ValueError(error)
    agent_ids: Set = set(env.agents)
    agent_ids.add("__all__")
    if allow_common:
        agent_ids.add("__common__")

    if not all(k in agent_ids for k in element):
        if base_env:
            error = (
                f"The element returned by {function_string} has agent_ids"
                f" that are not the names of the agents in the env."
                f"agent_ids in this\nMultiEnvDict:"
                f" {list(element.keys())}\nAgentIDs in this env: "
                f"{env.agents}"
            )
        else:
            error = (
                f"The element returned by {function_string} has agent_ids"
                f" that are not the names of the agents in the env. "
                f"\nAgentIDs in this MultiAgentDict: "
                f"{list(element.keys())}\nAgentIDs in this env: "
                f"{env.agents}. You likely need to add the attribute `agents` to your "
                f"env, which is a list containing the IDs of agents currently in your "
                f"env/episode, as well as, `possible_agents`, which is a list of all "
                f"possible agents that could ever show up in your env."
            )
        raise ValueError(error)


def _find_offending_sub_space(space, value):
    """Returns the offending sub-space/sub-value when `space.contains(value)` fails.

    Returns only the offending sub-value/sub-space in case `space` is a complex Tuple
    or Dict space.

    Args:
        space: The gym.Space to check.
        value: The actual (numpy) value to check for matching `space`.

    Returns:
        Tuple consisting of 1) the "->"-joined key path of the offending sub-space
        (Tuple indices are stringified), or None if `space` is not complex (Tuple
        or Dict), 2) the offending sub-space, 3) the offending sub-space's dtype,
        4) the offending sub-value, 5) the offending sub-value's dtype (or its
        Python type if it has no dtype). If `space` is complex but no sub-space
        rejects its sub-value, all five entries are None.

    .. testcode::
        :skipif: True

        path, space, space_dtype, value, value_dtype = _find_offending_sub_space(
            gym.spaces.Dict({"a": gym.spaces.Box(-2.0, 1.5, (2, ), np.int8)}),
            {"a": np.array([-1.5, 3.0])},
        )

    """
    if not isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
        return None, space, space.dtype, value, _get_type(value)

    structured_space = get_base_struct_from_space(space)

    def map_fn(path, sub_space, sub_value):
        # Abort the traversal at the first sub-value its sub-space rejects.
        if not sub_space.contains(sub_value):
            raise UnsupportedSpaceException((path, sub_space, sub_value))

    try:
        tree.map_structure_with_path(map_fn, structured_space, value)
    except UnsupportedSpaceException as e:
        path, space, value = e.args[0]
        # Paths through Tuple spaces contain int indices, which `str.join`
        # cannot handle directly.
        return (
            "->".join(str(p) for p in path),
            space,
            space.dtype,
            value,
            _get_type(value),
        )

    # This is actually an error.
    return None, None, None, None, None


def _get_type(var):
    return var.dtype if hasattr(var, "dtype") else type(var)
