from typing import List, Tuple

import numpy as np

from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.util.annotations import DeveloperAPI


@DeveloperAPI
def add_one_ts_to_episodes_and_truncate(episodes: List[SingleAgentEpisode]):
    """Adds an artificial timestep to an episode at the end.

    In detail: The last observations, infos, actions, and all `extra_model_outputs`
    will be duplicated and appended to each episode's data. An extra 0.0 reward
    will be appended to the episode's rewards. The episode's timestep will be
    increased by 1. Also, adds the truncated=True flag to each episode if the
    episode is not already done (terminated or truncated).

    Useful for value function bootstrapping, where it is required to compute a
    forward pass for the very last timestep within the episode,
    i.e. using the following input dict: {
      obs=[final obs],
      state=[final state output],
      prev. reward=[final reward],
      etc..
    }

    Args:
        episodes: The list of SingleAgentEpisode objects to extend by one timestep
            and add a truncation flag if necessary.

    Returns:
        A list of the original episodes' truncated values (so the episodes can be
        properly restored later into their original states).
    """
    orig_truncateds = []
    for episode in episodes:
        orig_truncateds.append(episode.is_truncated)

        # Add timestep.
        episode.t += 1
        # Use the episode API that allows appending (possibly complex) structs
        # to the data.
        episode.observations.append(episode.observations[-1])
        episode.infos.append(episode.infos[-1])
        episode.actions.append(episode.actions[-1])
        episode.rewards.append(0.0)
        for v in episode.extra_model_outputs.values():
            v.append(v[-1])
        # Artificially make this episode truncated for the upcoming GAE
        # computations.
        if not episode.is_done:
            episode.is_truncated = True
        # Validate to make sure, everything is in order.
        episode.validate()

    return orig_truncateds


@DeveloperAPI
def remove_last_ts_from_data(
    episode_lens: List[int],
    *data: Tuple[np._typing.NDArray],
) -> Tuple[np._typing.NDArray]:
    """Removes the last timesteps from each given data item.

    Each item in data is a concatenated sequence of episodes data.
    For example if `episode_lens` is [2, 4], then data is a shape=(6,)
    ndarray. The returned corresponding value will have shape (4,), meaning
    both episodes have been shortened by exactly one timestep to 1 and 3.

    ..testcode::

        from ray.rllib.algorithms.ppo.ppo_learner import PPOLearner
        import numpy as np

        unpadded = PPOLearner._remove_last_ts_from_data(
            [5, 3],
            np.array([0, 1, 2, 3, 4,  0, 1, 2]),
        )
        assert (unpadded[0] == [0, 1, 2, 3, 0, 1]).all()

        unpadded = PPOLearner._remove_last_ts_from_data(
            [4, 2, 3],
            np.array([0, 1, 2, 3,  0, 1,  0, 1, 2]),
            np.array([4, 5, 6, 7,  2, 3,  3, 4, 5]),
        )
        assert (unpadded[0] == [0, 1, 2,  0,  0, 1]).all()
        assert (unpadded[1] == [4, 5, 6,  2,  3, 4]).all()

    Args:
        episode_lens: A list of current episode lengths. The returned
            data will have the same lengths minus 1 timestep.
        data: A tuple of data items (np.ndarrays) representing concatenated episodes
            to be shortened by one timestep per episode.
            Note that only arrays with `shape=(n,)` are supported! The
            returned data will have `shape=(n-len(episode_lens),)` (each
            episode gets shortened by one timestep).

    Returns:
        A tuple of new data items shortened by one timestep.
    """
    # Figure out the new slices to apply to each data item based on
    # the given episode_lens.
    slices = []
    sum = 0
    for len_ in episode_lens:
        slices.append(slice(sum, sum + len_ - 1))
        sum += len_
    # Compiling return data by slicing off one timestep at the end of
    # each episode.
    ret = []
    for d in data:
        ret.append(np.concatenate([d[s] for s in slices]))
    return tuple(ret) if len(ret) > 1 else ret[0]


@DeveloperAPI
def remove_last_ts_from_episodes_and_restore_truncateds(
    episodes: List[SingleAgentEpisode],
    orig_truncateds: List[bool],
) -> None:
    """Reverts the effects of `_add_ts_to_episodes_and_truncate`.

    Args:
        episodes: The list of SingleAgentEpisode objects to extend by one timestep
            and add a truncation flag if necessary.
        orig_truncateds: A list of the original episodes' truncated values to be
            applied to the `episodes`.
    """

    # Fix all episodes.
    for episode, orig_truncated in zip(episodes, orig_truncateds):
        # Reduce timesteps by 1.
        episode.t -= 1
        # Remove all extra timestep data from the episode's buffers.
        episode.observations.pop()
        episode.infos.pop()
        episode.actions.pop()
        episode.rewards.pop()
        for v in episode.extra_model_outputs.values():
            v.pop()
        # Fix the truncateds flag again.
        episode.is_truncated = orig_truncated
