from typing import Any, Callable, Dict, List, Optional

from ray.rllib.callbacks.callbacks import RLlibCallback
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import OldAPIStack


def make_callback(
    callback_name: str,
    callbacks_objects: Optional[List[RLlibCallback]] = None,
    callbacks_functions: Optional[List[Callable]] = None,
    *,
    args: List[Any] = None,
    kwargs: Dict[str, Any] = None,
) -> None:
    """Fires one named callback on RLlibCallback objects and on plain callables.

    The RLlibCallback objects are served first, then the callables. Within each
    group, the order of the given list is kept. Every call receives the same
    `args` and `kwargs`.

    Args:
        callback_name: Name of the callback to fire, for example
            "on_episode_start" or "on_train_result". On each entry of
            `callbacks_objects`, the method with this name is looked up and
            called.
        callbacks_objects: A single RLlibCallback object or a list of them.
        callbacks_functions: A single callable or a list of callables. These
            are called directly (`callback_name` is not used for them).
        args: Positional args for each call. None (or empty) means no
            positional args.
        kwargs: Keyword args for each call. None (or empty) means no keyword
            args.
    """

    def _invoke(target: Callable) -> None:
        # Substitute empty args/kwargs whenever None (or empty) was given.
        target(*(args or ()), **(kwargs or {}))

    # Step 1: The method named `callback_name` on each RLlibCallback object.
    for obj in force_list(callbacks_objects):
        _invoke(getattr(obj, callback_name))

    # Step 2: Each plain callable.
    for fn in force_list(callbacks_functions):
        _invoke(fn)


@OldAPIStack
def _make_multi_callbacks(callback_class_list):
    """Builds an RLlibCallback subclass that relays hooks to several callbacks.

    Args:
        callback_class_list: The callback classes to combine. Each one is
            instantiated (without arguments) whenever the returned class is
            instantiated.

    Returns:
        The `_MultiCallbacks` class itself (not an instance of it).
    """

    def _relay(targets, hook_name, hook_kwargs):
        # Call `hook_name` on every target, in list order, handing each one the
        # same keyword arguments.
        for target in targets:
            getattr(target, hook_name)(**hook_kwargs)

    class _MultiCallbacks(RLlibCallback):
        """Holds one instance per wrapped callback class and forwards hooks."""

        IS_CALLBACK_CONTAINER = True

        def __init__(self):
            super().__init__()
            # One fresh, argument-less instance per wrapped class, in order.
            self._callback_list = [cls() for cls in callback_class_list]

        def on_algorithm_init(self, **kwargs) -> None:
            _relay(self._callback_list, "on_algorithm_init", kwargs)

        def on_workers_recreated(self, **kwargs) -> None:
            _relay(self._callback_list, "on_workers_recreated", kwargs)

        # Only on new API stack.
        def on_env_runners_recreated(self, **kwargs) -> None:
            pass

        def on_offline_eval_runners_recreated(self, **kwargs) -> None:
            pass

        def on_checkpoint_loaded(self, **kwargs) -> None:
            _relay(self._callback_list, "on_checkpoint_loaded", kwargs)

        def on_create_policy(self, *, policy_id, policy) -> None:
            _relay(
                self._callback_list,
                "on_create_policy",
                {"policy_id": policy_id, "policy": policy},
            )

        def on_environment_created(self, **kwargs) -> None:
            _relay(self._callback_list, "on_environment_created", kwargs)

        def on_sub_environment_created(self, **kwargs) -> None:
            _relay(self._callback_list, "on_sub_environment_created", kwargs)

        def on_episode_created(self, **kwargs) -> None:
            _relay(self._callback_list, "on_episode_created", kwargs)

        def on_episode_start(self, **kwargs) -> None:
            _relay(self._callback_list, "on_episode_start", kwargs)

        def on_episode_step(self, **kwargs) -> None:
            _relay(self._callback_list, "on_episode_step", kwargs)

        def on_episode_end(self, **kwargs) -> None:
            _relay(self._callback_list, "on_episode_end", kwargs)

        def on_evaluate_start(self, **kwargs) -> None:
            _relay(self._callback_list, "on_evaluate_start", kwargs)

        def on_evaluate_end(self, **kwargs) -> None:
            _relay(self._callback_list, "on_evaluate_end", kwargs)

        # TODO (simon, sven): Fix the test such that we can simply remove
        # these.
        def on_evaluate_offline_start(self, **kwargs):
            _relay(self._callback_list, "on_evaluate_offline_start", kwargs)

        def on_evaluate_offline_end(self, **kwargs):
            _relay(self._callback_list, "on_evaluate_offline_end", kwargs)

        def on_postprocess_trajectory(
            self,
            *,
            worker,
            episode,
            agent_id,
            policy_id,
            policies,
            postprocessed_batch,
            original_batches,
            **kwargs,
        ) -> None:
            # Named arguments first, then any extra kwargs, as one mapping.
            forwarded = {
                "worker": worker,
                "episode": episode,
                "agent_id": agent_id,
                "policy_id": policy_id,
                "policies": policies,
                "postprocessed_batch": postprocessed_batch,
                "original_batches": original_batches,
                **kwargs,
            }
            _relay(self._callback_list, "on_postprocess_trajectory", forwarded)

        def on_sample_end(self, **kwargs) -> None:
            _relay(self._callback_list, "on_sample_end", kwargs)

        def on_learn_on_batch(
            self, *, policy, train_batch, result: dict, **kwargs
        ) -> None:
            forwarded = {
                "policy": policy,
                "train_batch": train_batch,
                "result": result,
                **kwargs,
            }
            _relay(self._callback_list, "on_learn_on_batch", forwarded)

        def on_train_result(self, **kwargs) -> None:
            _relay(self._callback_list, "on_train_result", kwargs)

    return _MultiCallbacks
