import logging
from typing import List, Optional, Union

import tree

from ray.rllib.env.env_runner_group import EnvRunnerGroup
from ray.rllib.policy.sample_batch import (
    DEFAULT_POLICY_ID,
    SampleBatch,
    concat_samples,
)
from ray.rllib.utils.annotations import ExperimentalAPI, OldAPIStack
from ray.rllib.utils.metrics import NUM_AGENT_STEPS_SAMPLED, NUM_ENV_STEPS_SAMPLED
from ray.rllib.utils.sgd import standardized
from ray.rllib.utils.typing import EpisodeType, SampleBatchType

logger = logging.getLogger(__name__)


@ExperimentalAPI
def synchronous_parallel_sample(
    *,
    worker_set: EnvRunnerGroup,
    max_agent_steps: Optional[int] = None,
    max_env_steps: Optional[int] = None,
    concat: bool = True,
    sample_timeout_s: Optional[float] = None,
    random_actions: bool = False,
    _uses_new_env_runners: bool = False,
    _return_metrics: bool = False,
) -> Union[List[SampleBatchType], SampleBatchType, List[EpisodeType], EpisodeType]:
    """Runs parallel and synchronous rollouts on all remote workers.

    Waits for all workers to return from the remote calls.

    If no remote workers exist (`worker_set.num_remote_workers() <= 0`), the
    local worker is used for sampling instead.

    Sampling is done in rounds: If neither `max_agent_steps` nor `max_env_steps`
    is given, rounds are repeated until some steps have been collected. Otherwise,
    rounds are repeated until at least the requested number of steps has been
    collected.

    Args:
        worker_set: The EnvRunnerGroup to use for sampling.
        max_agent_steps: Optional number of agent steps to collect. Sampling
            continues until at least this many agent steps have been gathered.
            Mutually exclusive with `max_env_steps`.
        max_env_steps: Optional number of environment steps to collect. Sampling
            continues until at least this many env steps have been gathered.
            Mutually exclusive with `max_agent_steps`.
        concat: Whether to aggregate all resulting batches or episodes. In case of
            batches, the list of batches is concatenated at the end. In case of
            episodes, all episode lists from workers are flattened into a single
            list.
        sample_timeout_s: The timeout in sec to use on the `foreach_env_runner` call.
            After this time, the call will return with a result (or not if all
            EnvRunners are stalling). If None, will block indefinitely and not timeout.
        random_actions: If True, `random_actions=True` is passed to each
            `sample()` call. Otherwise, `sample()` is called without arguments.
        _uses_new_env_runners: Whether the new `EnvRunner API` is used. In this case
            episodes instead of `SampleBatch` objects are returned.
        _return_metrics: If True, `get_metrics()` is called on each worker right
            after `sample()` and the collected metrics dicts are returned alongside
            the samples. In this mode, the step counts used for the stopping
            criterion are read from these metrics dicts rather than from the
            sampled data.

    Returns:
        The list of collected sample batch types or episode types (one for each parallel
        rollout worker in the given `worker_set`), or - if `concat` is True - the
        single concatenated batch / flattened list of episodes. If `_return_metrics`
        is True, a tuple consisting of these samples and the list of all collected
        metrics dicts.

    .. testcode::

        # Define an RLlib Algorithm.
        from ray.rllib.algorithms.ppo import PPO, PPOConfig
        config = (
            PPOConfig()
            .environment("CartPole-v1")
        )
        algorithm = config.build()
        # 2 remote EnvRunners (num_env_runners=2):
        episodes = synchronous_parallel_sample(
            worker_set=algorithm.env_runner_group,
            _uses_new_env_runners=True,
            concat=False,
        )
        print(len(episodes))

    .. testoutput::

        2
    """
    # Only allow one of `max_agent_steps` or `max_env_steps` to be defined.
    assert (
        max_agent_steps is None or max_env_steps is None
    ), "Only one of `max_agent_steps` or `max_env_steps` may be provided!"

    agent_or_env_steps = 0
    max_agent_or_env_steps = max_agent_steps or max_env_steps or None
    sample_batches_or_episodes = []
    all_stats_dicts = []

    random_action_kwargs = {"random_actions": True} if random_actions else {}

    # The callable applied to each remote worker does not change between sampling
    # rounds -> Build it once, outside the loop.
    if _return_metrics:

        def _sample_on_worker(w):
            return w.sample(**random_action_kwargs), w.get_metrics()

    else:

        def _sample_on_worker(w):
            return w.sample(**random_action_kwargs)

    # Stop collecting batches as soon as one criterium is met.
    while (max_agent_or_env_steps is None and agent_or_env_steps == 0) or (
        max_agent_or_env_steps is not None
        and agent_or_env_steps < max_agent_or_env_steps
    ):
        # No remote workers in the set -> Use local worker for collecting
        # samples.
        if worker_set.num_remote_workers() <= 0:
            local_env_runner = worker_set.local_env_runner
            sampled_data = [local_env_runner.sample(**random_action_kwargs)]
            if _return_metrics:
                stats_dicts = [local_env_runner.get_metrics()]
        # Loop over remote workers' `sample()` method in parallel.
        else:
            results = worker_set.foreach_env_runner(
                _sample_on_worker,
                local_env_runner=False,
                timeout_seconds=sample_timeout_s,
            )
            # Nothing was returned (maybe all workers are stalling): Break.
            if not results:
                logger.warning(
                    "No samples returned from remote workers. If you have a "
                    "slow environment or model, consider increasing the "
                    "`sample_timeout_s` or decreasing the "
                    "`rollout_fragment_length` in `AlgorithmConfig.env_runners()."
                )
                break
            # No healthy remote workers left: Break.
            # There is no point staying in this loop, since we will not be able to
            # get any new samples if we don't have any healthy remote workers left.
            # NOTE: The results of this round are not added to the output.
            if worker_set.num_healthy_remote_workers() <= 0:
                logger.warning(
                    "No healthy remote workers left. Trying to restore workers ..."
                )
                break

            if _return_metrics:
                # Each result is a `(samples, metrics)` pair (see
                # `_sample_on_worker`).
                stats_dicts = [metrics for _, metrics in results]
                sampled_data = [samples for samples, _ in results]
            else:
                sampled_data = results

        # Update our counters for the stopping criterion of the while loop.
        if _return_metrics:
            # Step counts come from the workers' metrics dicts.
            if max_agent_steps:
                agent_or_env_steps += sum(
                    int(agent_stat)
                    for stat_dict in stats_dicts
                    for agent_stat in stat_dict[NUM_AGENT_STEPS_SAMPLED].values()
                )
            else:
                agent_or_env_steps += sum(
                    int(stat_dict[NUM_ENV_STEPS_SAMPLED]) for stat_dict in stats_dicts
                )
            sample_batches_or_episodes.extend(sampled_data)
            all_stats_dicts.extend(stats_dicts)
        else:
            # Step counts come from the sampled data itself.
            for batch_or_episode in sampled_data:
                if max_agent_steps:
                    agent_or_env_steps += (
                        sum(e.agent_steps() for e in batch_or_episode)
                        if _uses_new_env_runners
                        else batch_or_episode.agent_steps()
                    )
                else:
                    agent_or_env_steps += (
                        sum(e.env_steps() for e in batch_or_episode)
                        if _uses_new_env_runners
                        else batch_or_episode.env_steps()
                    )
                sample_batches_or_episodes.append(batch_or_episode)
                # Break out (and ignore the remaining samples) if max timesteps (batch
                # size) reached. We want to avoid collecting batches that are too large
                # only because of a failed/restarted worker causing a second iteration
                # of the main loop.
                # This only leaves the for-loop; the while-condition above then
                # terminates the outer loop.
                if (
                    max_agent_or_env_steps is not None
                    and agent_or_env_steps >= max_agent_or_env_steps
                ):
                    break

    if concat is True:
        # If we have episodes flatten the episode list.
        if _uses_new_env_runners:
            sample_batches_or_episodes = tree.flatten(sample_batches_or_episodes)
        # Otherwise we concatenate the `SampleBatch` objects
        else:
            sample_batches_or_episodes = concat_samples(sample_batches_or_episodes)

    if _return_metrics:
        return sample_batches_or_episodes, all_stats_dicts
    return sample_batches_or_episodes


@OldAPIStack
def standardize_fields(samples: SampleBatchType, fields: List[str]) -> SampleBatchType:
    """Applies `standardized()` to the given fields of every policy batch.

    Args:
        samples: The batch to process. A plain `SampleBatch` is converted via
            `as_multi_agent()` first, so both cases share the same code path.
        fields: Names of the fields to standardize. Fields that are not present
            in a policy batch are skipped for that batch.

    Returns:
        For a `SampleBatch` input, the batch stored under `DEFAULT_POLICY_ID`
        after the conversion. Otherwise, the passed-in `samples` object, whose
        policy batches have been updated in place.
    """
    # Remember the input kind so that the same kind can be handed back.
    was_single_agent = isinstance(samples, SampleBatch)
    if was_single_agent:
        samples = samples.as_multi_agent()

    for policy_batch in samples.policy_batches.values():
        for name in fields:
            if name in policy_batch:
                policy_batch[name] = standardized(policy_batch[name])

    if was_single_agent:
        return samples.policy_batches[DEFAULT_POLICY_ID]
    return samples
