import logging
import math
import os
from typing import Any, Dict, List, Optional

import numpy as np

from ray.data import Dataset
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
from ray.rllib.offline.offline_evaluation_utils import compute_q_and_v_values
from ray.rllib.offline.offline_evaluator import OfflineEvaluator
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch, convert_ma_batch_to_sample_batch
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.typing import SampleBatchType

# Module-scoped logger (named after this module rather than the root logger),
# so log records can be filtered/configured per module.
logger = logging.getLogger(__name__)


@DeveloperAPI
class DirectMethod(OffPolicyEstimator):
    r"""The Direct Method estimator.

    Let s_t, a_t, and r_t be the state, action, and reward at timestep t.

    This method trains a Q-model for the evaluation policy \pi_e on behavior
    data generated by \pi_b. Currently, RLlib implements this using
    Fitted-Q Evaluation (FQE). You can also implement your own model
    and pass it in as `q_model_config = {"type": your_model_class, **your_kwargs}`.

    This estimator computes the expected return for \pi_e for an episode as:
    V^{\pi_e}(s_0) = \sum_{a \in A} \pi_e(a | s_0) Q(s_0, a)
    and returns the mean and standard deviation over episodes.

    For more information refer to https://arxiv.org/pdf/1911.06854.pdf"""

    @override(OffPolicyEstimator)
    def __init__(
        self,
        policy: Policy,
        gamma: float,
        epsilon_greedy: float = 0.0,
        q_model_config: Optional[Dict] = None,
    ):
        """Initializes a Direct Method OPE Estimator.

        Args:
            policy: Policy to evaluate.
            gamma: Discount factor of the environment.
            epsilon_greedy: The probability by which we act according to a fully
                random policy during deployment. With probability
                1 - epsilon_greedy we act according to the target policy.
            q_model_config: Arguments to specify the Q-model. An optional
                `type` key selects the Q-model class (defaults to
                `FQETorchModel`); all remaining keys are forwarded as keyword
                arguments to that class's constructor. The passed-in dict is
                not modified.
                This Q-model is trained in the train() method and is used
                to compute the state-value estimates for the DirectMethod estimator.
                It must implement `train` and `estimate_v`.
                TODO (Rohan138): Unify this with RLModule API.
        """

        super().__init__(policy, gamma, epsilon_greedy)

        # Some dummy policies and ones that are not based on a tensor framework
        # backend can come without a config or without a framework key.
        if hasattr(policy, "config"):
            assert (
                policy.config.get("framework", "torch") == "torch"
            ), "Framework must be torch to use DirectMethod."

        # Shallow-copy before popping "type": mutating the caller's dict would
        # drop a custom model class the next time the same config is reused.
        q_model_config = dict(q_model_config or {})
        model_cls = q_model_config.pop("type", FQETorchModel)
        self.model = model_cls(
            policy=policy,
            gamma=gamma,
            **q_model_config,
        )
        assert hasattr(
            self.model, "estimate_v"
        ), "self.model must implement `estimate_v`!"

    @override(OffPolicyEstimator)
    def estimate_on_single_episode(self, episode: SampleBatch) -> Dict[str, Any]:
        """Estimates behavior and target values for a single episode.

        Args:
            episode: A SampleBatch containing a single episode.

        Returns:
            A dict with:
                v_behavior: The discounted return observed in `episode`,
                    i.e. sum_t gamma^t * r_t.
                v_target: The Q-model's value estimate (as numpy) for the
                    first row of `episode`, i.e. the initial timestep.
        """
        estimates_per_episode = {}
        rewards = episode["rewards"]

        v_behavior = 0.0
        for t in range(episode.count):
            v_behavior += rewards[t] * self.gamma**t

        # The direct method only needs the model's estimate at s_0.
        v_target = self._compute_v_target(episode[:1])

        estimates_per_episode["v_behavior"] = v_behavior
        estimates_per_episode["v_target"] = v_target

        return estimates_per_episode

    @override(OffPolicyEstimator)
    def estimate_on_single_step_samples(
        self, batch: SampleBatch
    ) -> Dict[str, List[float]]:
        """Estimates behavior and target values row by row.

        Args:
            batch: A SampleBatch whose rows are evaluated independently.

        Returns:
            A dict with:
                v_behavior: The `rewards` column of `batch`, unchanged.
                v_target: The Q-model's value estimates (as numpy) for every
                    row of `batch`.
        """
        estimates_per_episode = {}
        rewards = batch["rewards"]

        v_behavior = rewards
        v_target = self._compute_v_target(batch)

        estimates_per_episode["v_behavior"] = v_behavior
        estimates_per_episode["v_target"] = v_target

        return estimates_per_episode

    def _compute_v_target(self, init_step):
        """Returns `self.model.estimate_v(init_step)` converted to numpy."""
        v_target = self.model.estimate_v(init_step)
        v_target = convert_to_numpy(v_target)
        return v_target

    @override(OffPolicyEstimator)
    def train(self, batch: SampleBatchType) -> Dict[str, Any]:
        """Trains self.model on the given batch.

        Args:
            batch: A SampleBatchType to train on. Multi-agent batches are
                converted via `convert_ma_batch_to_sample_batch` first.

        Returns:
            A dict with key "loss" and value as the mean training loss.
        """
        batch = convert_ma_batch_to_sample_batch(batch)
        losses = self.model.train(batch)
        return {"loss": np.mean(losses)}

    @override(OfflineEvaluator)
    def estimate_on_dataset(
        self, dataset: Dataset, *, n_parallelism: Optional[int] = None
    ) -> Dict[str, Any]:
        """Calculates the Direct Method estimate on the given dataset.

        Note: This estimate works for only discrete action spaces for now.

        Args:
            dataset: Dataset to compute the estimate on. Each record in dataset should
                include the following columns: `obs`, `actions`, `action_prob` and
                `rewards`. The `obs` on each row should be a vector of D dimensions.
            n_parallelism: Roughly how many batches the dataset is split into
                when computing the state values via `map_batches`. Defaults to
                `os.cpu_count()` (or 1 if that is unavailable). Values below 1
                are treated as 1.

        Returns:
            Dictionary with the following keys:
                v_behavior: The mean of the `rewards` column.
                v_target: The mean of the model's state-value estimates
                    (`v_values`).
                v_gain_mean: v_target / v_behavior.
                v_gain_ste: The standard deviation of `v_values`, divided by
                    v_behavior and by sqrt(number of rows), i.e. a standard
                    error for the gain estimate.
        """
        if n_parallelism is None:
            n_parallelism = os.cpu_count() or 1

        # Count once and reuse; counting a Ray Dataset may be non-trivial.
        num_rows = dataset.count()

        # compute v_values
        batch_size = max(num_rows // max(n_parallelism, 1), 1)
        updated_ds = dataset.map_batches(
            compute_q_and_v_values,
            batch_size=batch_size,
            batch_format="pandas",
            fn_kwargs={
                "model_class": self.model.__class__,
                "model_state": self.model.get_state(),
                "compute_q_values": False,
            },
        )

        v_behavior = updated_ds.mean("rewards")
        v_target = updated_ds.mean("v_values")
        # NOTE(review): the gain ratios assume a non-zero mean reward.
        v_gain_mean = v_target / v_behavior
        v_gain_ste = updated_ds.std("v_values") / v_behavior / math.sqrt(num_rows)

        return {
            "v_behavior": v_behavior,
            "v_target": v_target,
            "v_gain_mean": v_gain_mean,
            "v_gain_ste": v_gain_ste,
        }
