"""Registry of algorithm names for tune.Tuner(trainable=[..])."""

import importlib
import re


def _import_appo():
    """Import APPO on demand and return ``(APPO, APPO.get_default_config())``."""
    from ray.rllib.algorithms.appo import APPO

    return APPO, APPO.get_default_config()


def _import_bc():
    """Import BC on demand and return ``(BC, BC.get_default_config())``."""
    from ray.rllib.algorithms.bc import BC

    return BC, BC.get_default_config()


def _import_cql():
    """Import CQL on demand and return ``(CQL, CQL.get_default_config())``."""
    from ray.rllib.algorithms.cql import CQL

    return CQL, CQL.get_default_config()


def _import_dqn():
    """Import DQN on demand and return ``(DQN, DQN.get_default_config())``."""
    from ray.rllib.algorithms.dqn import DQN

    return DQN, DQN.get_default_config()


def _import_dreamerv3():
    """Import DreamerV3 on demand and return it with its default config."""
    from ray.rllib.algorithms.dreamerv3 import DreamerV3

    return DreamerV3, DreamerV3.get_default_config()


def _import_impala():
    """Import IMPALA on demand and return it with its default config."""
    from ray.rllib.algorithms.impala import IMPALA

    return IMPALA, IMPALA.get_default_config()


def _import_iql():
    """Import IQL on demand and return ``(IQL, IQL.get_default_config())``."""
    from ray.rllib.algorithms.iql import IQL

    return IQL, IQL.get_default_config()


def _import_marwil():
    """Import MARWIL on demand and return it with its default config."""
    from ray.rllib.algorithms.marwil import MARWIL

    return MARWIL, MARWIL.get_default_config()


def _import_ppo():
    """Import PPO on demand and return ``(PPO, PPO.get_default_config())``."""
    from ray.rllib.algorithms.ppo import PPO

    return PPO, PPO.get_default_config()


def _import_sac():
    """Import SAC on demand and return ``(SAC, SAC.get_default_config())``."""
    from ray.rllib.algorithms.sac import SAC

    return SAC, SAC.get_default_config()


# Maps each registered algorithm name (the string accepted by
# tune.Tuner(trainable=...)) to a zero-argument loader. Calling a loader imports
# the algorithm module lazily and returns an (algorithm class, default config)
# tuple.
ALGORITHMS = {
    "APPO": _import_appo,
    "BC": _import_bc,
    "CQL": _import_cql,
    "DQN": _import_dqn,
    "DreamerV3": _import_dreamerv3,
    "IMPALA": _import_impala,
    "IQL": _import_iql,
    "MARWIL": _import_marwil,
    "PPO": _import_ppo,
    "SAC": _import_sac,
}


# Maps an algorithm class-name spelling to its registered name (a key of
# ALGORITHMS). Every entry is the identity except "Impala", an alternate
# capitalization that also resolves to "IMPALA".
ALGORITHMS_CLASS_TO_NAME = {
    "APPO": "APPO",
    "BC": "BC",
    "CQL": "CQL",
    "DQN": "DQN",
    "DreamerV3": "DreamerV3",
    "Impala": "IMPALA",  # Alias: same registered name as "IMPALA" below.
    "IQL": "IQL",
    "IMPALA": "IMPALA",
    "MARWIL": "MARWIL",
    "PPO": "PPO",
    "SAC": "SAC",
}


def _get_algorithm_class(alg: str) -> type:
    """Resolve a registered algorithm name to its Trainable class.

    Args:
        alg: Either a key of ``ALGORITHMS`` (e.g. "PPO"), "script", or one of
            the internal mock names "__fake", "__sigmoid_fake_data",
            "__parameter_tuning".

    Returns:
        The class registered under ``alg``.

    Raises:
        ValueError: If ``alg`` is not a known name. ValueError is a subclass
            of Exception, so callers that catch Exception still catch it.
    """
    # All imports are deferred to call time. This helps us get around a
    # circular import (tune calls rllib._register_all when checking if a rllib
    # Trainable is registered).
    if alg in ALGORITHMS:
        # Loaders return (algorithm class, default config); only the class is needed.
        return ALGORITHMS[alg]()[0]
    if alg == "script":
        from ray.tune import script_runner

        return script_runner.ScriptRunner
    if alg == "__fake":
        from ray.rllib.algorithms.mock import _MockTrainer

        return _MockTrainer
    if alg == "__sigmoid_fake_data":
        from ray.rllib.algorithms.mock import _SigmoidFakeData

        return _SigmoidFakeData
    if alg == "__parameter_tuning":
        from ray.rllib.algorithms.mock import _ParameterTuningTrainer

        return _ParameterTuningTrainer
    # Previously a bare Exception; ValueError is more precise and still
    # compatible with existing `except Exception` handlers.
    raise ValueError(f"Unknown algorithm {alg}.")


# Maps each policy class name to the module that defines it. The module is given
# as a dotted path relative to the `ray.rllib.algorithms` package.
# get_policy_class imports "ray.rllib.algorithms." + path and reads the class
# attribute from it. TF1 and TF2 variants of a policy share one module.
# TODO(jungong): Finish migrating all the policies to PolicyV2, so we can list
# all the TF eager policies here.
POLICIES = {
    "APPOTF1Policy": "appo.appo_tf_policy",
    "APPOTF2Policy": "appo.appo_tf_policy",
    "APPOTorchPolicy": "appo.appo_torch_policy",
    "CQLTFPolicy": "cql.cql_tf_policy",
    "CQLTorchPolicy": "cql.cql_torch_policy",
    "DQNTFPolicy": "dqn.dqn_tf_policy",
    "DQNTorchPolicy": "dqn.dqn_torch_policy",
    "ImpalaTF1Policy": "impala.impala_tf_policy",
    "ImpalaTF2Policy": "impala.impala_tf_policy",
    "ImpalaTorchPolicy": "impala.impala_torch_policy",
    "MARWILTF1Policy": "marwil.marwil_tf_policy",
    "MARWILTF2Policy": "marwil.marwil_tf_policy",
    "MARWILTorchPolicy": "marwil.marwil_torch_policy",
    "SACTFPolicy": "sac.sac_tf_policy",
    "SACTorchPolicy": "sac.sac_torch_policy",
    "PPOTF1Policy": "ppo.ppo_tf_policy",
    "PPOTF2Policy": "ppo.ppo_tf_policy",
    "PPOTorchPolicy": "ppo.ppo_torch_policy",
}


def get_policy_class_name(policy_class: type):
    """Return the registry name for ``policy_class``, if it is registered.

    Args:
        policy_class: RLlib policy class, e.g. DQNTFPolicy, PPOTorchPolicy, etc.

    Returns:
        The class name with any trailing "_traced" suffix removed, provided
        that name is a key of ``POLICIES``. Otherwise None.
    """
    # Eager-tracing TF2 policy classes are auto-generated and carry a "_traced"
    # suffix. Checkpoints should store the original class name so users can
    # still choose whether to enable eager tracing at inference time.
    untraced_name = re.sub(r"_traced$", "", policy_class.__name__)
    return untraced_name if untraced_name in POLICIES else None


def get_policy_class(name: str):
    """Import and return the policy class registered under ``name``.

    Args:
        name: Policy class name, looked up as a key of ``POLICIES``.

    Returns:
        The policy class, or None if ``name`` is not in ``POLICIES`` or the
        mapped module has no attribute called ``name``.
    """
    try:
        relative_path = POLICIES[name]
    except KeyError:
        return None

    # POLICIES stores module paths relative to the ray.rllib.algorithms package.
    policy_module = importlib.import_module(
        ".".join(("ray.rllib.algorithms", relative_path))
    )
    return getattr(policy_module, name, None)
