import copy
import functools
import logging
import math
import os
import threading
import time
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
)

import gymnasium as gym
import numpy as np
import tree  # pip install dm_tree

import ray
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import Policy, PolicyState
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import NullContextManager, force_list
from ray.rllib.utils.annotations import OldAPIStack, override
from ray.rllib.utils.error import ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import (
    DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
    NUM_AGENT_STEPS_TRAINED,
    NUM_GRAD_UPDATES_LIFETIME,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import (
    AlgorithmConfigDict,
    GradInfoDict,
    ModelGradients,
    ModelWeights,
    TensorStructType,
    TensorType,
)

# Obtain the torch modules through RLlib's `try_import_torch` helper rather
# than importing torch directly at module level.
torch, nn = try_import_torch()

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


@OldAPIStack
class TorchPolicy(Policy):
    """PyTorch specific Policy class to use with RLlib."""

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
        *,
        model: Optional[TorchModelV2] = None,
        loss: Optional[
            Callable[
                [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch],
                Union[TensorType, List[TensorType]],
            ]
        ] = None,
        action_distribution_class: Optional[Type[TorchDistributionWrapper]] = None,
        action_sampler_fn: Optional[
            Callable[
                [TensorType, List[TensorType]],
                Union[
                    Tuple[TensorType, TensorType, List[TensorType]],
                    Tuple[TensorType, TensorType, TensorType, List[TensorType]],
                ],
            ]
        ] = None,
        action_distribution_fn: Optional[
            Callable[
                [Policy, ModelV2, TensorType, TensorType, TensorType],
                Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]],
            ]
        ] = None,
        max_seq_len: int = 20,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
    ):
        """Initializes a TorchPolicy instance.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: The Policy's config dict. Note that its "framework" key
                is overwritten with "torch" by this constructor.
            model: PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
                If None, a default model is built via the `ModelCatalog`.
            loss: Callable that returns one or more (a list of) scalar loss
                terms.
            action_distribution_class: Class for a torch action distribution.
            action_sampler_fn: A callable returning either a sampled action,
                its log-likelihood and updated state or a sampled action, its
                log-likelihood, updated state and action distribution inputs
                given Policy, ModelV2, input_dict, state batches (optional),
                explore, and timestep. Provide `action_sampler_fn` if you would
                like to have full control over the action computation step,
                including the model forward pass, possible sampling from a
                distribution, and exploration logic.
                Note: If `action_sampler_fn` is given, `action_distribution_fn`
                must be None. If both `action_sampler_fn` and
                `action_distribution_fn` are None, RLlib will simply pass
                inputs through `self.model` to get distribution inputs, create
                the distribution object, sample from it, and apply some
                exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, input_dict
                (SampleBatch), state_batches (optional), explore, and timestep.
            action_distribution_fn: A callable returning distribution inputs
                (parameters), a dist-class to generate an action distribution
                object from, and internal-state outputs (or an empty list if
                not applicable).
                Provide `action_distribution_fn` if you would like to only
                customize the model forward pass call. The resulting
                distribution parameters are then used by RLlib to create a
                distribution object, sample from it, and execute any
                exploration logic.
                Note: If `action_distribution_fn` is given, `action_sampler_fn`
                must be None. If both `action_sampler_fn` and
                `action_distribution_fn` are None, RLlib will simply pass
                inputs through `self.model` to get distribution inputs, create
                the distribution object, sample from it, and apply some
                exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, ModelInputDict,
                explore, timestep, is_training.
            max_seq_len: Max sequence length for LSTM training.
            get_batch_divisibility_req: Optional callable that returns the
                divisibility requirement for sample batches given the Policy.

        Raises:
            ValueError: If real GPUs are to be used, but fewer GPU IDs are
                found than `num_gpus` asks for.
        """
        # Force the torch framework, both on this instance and inside the
        # passed-in config dict (which is mutated in place).
        self.framework = config["framework"] = "torch"
        self._loss_initialized = False
        super().__init__(observation_space, action_space, config)

        # Create multi-GPU model towers, if necessary.
        # - The central main model will be stored under self.model, residing
        #   on self.device (normally, a CPU).
        # - Each GPU will have a copy of that model under
        #   self.model_gpu_towers, matching the devices in self.devices.
        # - Parallelization is done by splitting the train batch and passing
        #   it through the model copies in parallel, then averaging over the
        #   resulting gradients, applying these averages on the main model and
        #   updating all towers' weights from the main model.
        # - In case of just one device (1 (fake or real) GPU or 1 CPU), no
        #   parallelization will be done.

        # If no Model is provided, build a default one here.
        if model is None:
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                action_space, self.config["model"], framework=self.framework
            )
            model = ModelCatalog.get_model_v2(
                obs_space=self.observation_space,
                action_space=self.action_space,
                num_outputs=logit_dim,
                model_config=self.config["model"],
                framework=self.framework,
            )
            # Only fall back to the catalog's dist class if the caller did
            # not provide one explicitly.
            if action_distribution_class is None:
                action_distribution_class = dist_class

        # Get devices to build the graph on.
        num_gpus = self._get_num_gpus_for_policy()
        gpu_ids = list(range(torch.cuda.device_count()))
        logger.info(f"Found {len(gpu_ids)} visible cuda devices.")

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - No GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            self.device = torch.device("cpu")
            # `or 1` guarantees at least one CPU device/tower, even for
            # num_gpus=0.
            self.devices = [self.device for _ in range(int(math.ceil(num_gpus)) or 1)]
            # Tower 0 is the given model itself; any further towers are deep
            # copies of it.
            self.model_gpu_towers = [
                model if i == 0 else copy.deepcopy(model)
                for i in range(int(math.ceil(num_gpus)) or 1)
            ]
            # `target_model` is not assigned in this constructor. If the
            # attribute already exists at this point, all CPU towers share
            # that one target model object.
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: self.target_model for m in self.model_gpu_towers
                }
            self.model = model
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray._private.worker._mode() == ray._private.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TorchPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}."
                )

            # Devices are addressed by their position in `gpu_ids`
            # ("cuda:0", "cuda:1", ...), not by the ID values themselves.
            self.devices = [
                torch.device("cuda:{}".format(i))
                for i, id_ in enumerate(gpu_ids)
                if i < num_gpus
            ]
            self.device = self.devices[0]
            ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus]
            # One deep copy of the model per used GPU, each moved to its
            # device.
            self.model_gpu_towers = []
            for i, _ in enumerate(ids):
                model_copy = copy.deepcopy(model)
                self.model_gpu_towers.append(model_copy.to(self.devices[i]))
            # On GPUs, each tower gets its own deep copy of the target model
            # (if one exists), moved to that tower's device.
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: copy.deepcopy(self.target_model).to(self.devices[i])
                    for i, m in enumerate(self.model_gpu_towers)
                }
            self.model = self.model_gpu_towers[0]

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when calling the model
        # first, then its value function (e.g. in a loss function), in
        # between of which another model call is made (e.g. to compute an
        # action).
        self._lock = threading.RLock()

        self._state_inputs = self.model.get_initial_state()
        # A non-empty initial state marks the model as recurrent.
        self._is_recurrent = len(self._state_inputs) > 0
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)

        self.exploration = self._create_exploration()
        self.unwrapped_model = model  # used to support DistributedDataParallel
        # To ensure backward compatibility:
        # Old way: If `loss` provided here, use as-is (as a function).
        if loss is not None:
            self._loss = loss
        # New way: Convert the overridden `self.loss` into a plain function,
        # so it can be called the same way as `loss` would be, ensuring
        # backward compatibility.
        elif self.loss.__func__.__qualname__ != "Policy.loss":
            self._loss = self.loss.__func__
        # `loss` not provided nor overridden from Policy -> Set to None.
        else:
            self._loss = None
        self._optimizers = force_list(self.optimizer())
        # Store, which params (by index within the model's list of
        # parameters) should be updated per optimizer.
        # Maps optimizer idx to set or param indices.
        self.multi_gpu_param_groups: List[Set[int]] = []
        # Lookup from parameter object to its position within
        # `self.model.parameters()`.
        main_params = {p: i for i, p in enumerate(self.model.parameters())}
        for o in self._optimizers:
            param_indices = []
            for pg_idx, pg in enumerate(o.param_groups):
                for p in pg["params"]:
                    param_indices.append(main_params[p])
            self.multi_gpu_param_groups.append(set(param_indices))

        # Create n sample-batch buffers (num_multi_gpu_tower_stacks), each
        # one with m towers (num_gpus).
        num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1)
        self._loaded_batches = [[] for _ in range(num_buffers)]

        self.dist_class = action_distribution_class
        self.action_sampler_fn = action_sampler_fn
        self.action_distribution_fn = action_distribution_fn

        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.max_seq_len = max_seq_len
        # `get_batch_divisibility_req` may be a callable (invoked with this
        # policy) or a plain value; a falsy non-callable value falls back
        # to 1.
        self.batch_divisibility_req = (
            get_batch_divisibility_req(self)
            if callable(get_batch_divisibility_req)
            else (get_batch_divisibility_req or 1)
        )

    @override(Policy)
    def compute_actions_from_input_dict(
        self,
        input_dict: Dict[str, TensorType],
        explore: bool = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions for the inputs given in `input_dict`.

        The whole computation runs under `torch.no_grad()`.

        Args:
            input_dict: The model inputs. Every entry whose key contains
                "state_in" within its first 8 characters is treated as an
                internal state input.
            explore: Passed through to `self._compute_action_helper`.
            timestep: Passed through to `self._compute_action_helper`.
            **kwargs: Not used by this implementation.

        Returns:
            Whatever `self._compute_action_helper` returns for these inputs.
        """
        with torch.no_grad():
            # Hand the inputs to the model as a lazy (torch) tensor dict.
            lazy_inputs = self._lazy_tensor_dict(input_dict)
            # NOTE(review): the batch is flagged as "training" on this
            # inference path, whereas `compute_actions` passes
            # `is_training: False` -- confirm that this is intended.
            lazy_inputs.set_training(True)

            # Gather the internal state inputs into a separate list.
            state_batches = []
            for key in lazy_inputs.keys():
                if "state_in" in key[:8]:
                    state_batches.append(lazy_inputs[key])

            # With state inputs present, use a sequence length of 1 for
            # every row of the first state batch; otherwise pass no
            # sequence lengths at all.
            if state_batches:
                first_state = state_batches[0]
                seq_lens = torch.tensor(
                    [1] * len(first_state),
                    dtype=torch.long,
                    device=first_state.device,
                )
            else:
                seq_lens = None

            return self._compute_action_helper(
                lazy_inputs, state_batches, seq_lens, explore, timestep
            )

    @override(Policy)
    def compute_actions(
        self,
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
        prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes=None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions for a batch of observations.

        The whole computation runs under `torch.no_grad()`.

        Args:
            obs_batch: Batch of observations; its `len()` determines the
                number of (all-ones) sequence lengths.
            state_batches: Optional list of state batches. Each one is
                converted to a torch tensor on `self.device`.
            prev_action_batch: Optional previous actions; added to the model
                inputs as a numpy array when not None.
            prev_reward_batch: Optional previous rewards; added to the model
                inputs as a numpy array when not None.
            info_batch: Not used by this implementation.
            episodes: Not used by this implementation.
            explore: Passed through to `self._compute_action_helper`.
            timestep: Passed through to `self._compute_action_helper`.
            **kwargs: Not used by this implementation.

        Returns:
            Whatever `self._compute_action_helper` returns for these inputs.
        """
        with torch.no_grad():
            model_inputs = {
                SampleBatch.CUR_OBS: obs_batch,
                "is_training": False,
            }
            # One timestep per batch item.
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict(model_inputs)

            # Optional previous actions/rewards are only added when given.
            optional_inputs = (
                (SampleBatch.PREV_ACTIONS, prev_action_batch),
                (SampleBatch.PREV_REWARDS, prev_reward_batch),
            )
            for key, values in optional_inputs:
                if values is not None:
                    input_dict[key] = np.asarray(values)

            # Move all given state batches onto this policy's device.
            state_tensors = []
            for state in state_batches or []:
                state_tensors.append(convert_to_torch_tensor(state, self.device))

            return self._compute_action_helper(
                input_dict, state_tensors, seq_lens, explore, timestep
            )

    @with_lock
    @override(Policy)
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorStructType], TensorStructType],
        obs_batch: Union[List[TensorStructType], TensorStructType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        prev_reward_batch: Optional[
            Union[List[TensorStructType], TensorStructType]
        ] = None,
        actions_normalized: bool = True,
        **kwargs,
    ) -> TensorType:
        """Computes the log-likelihoods of `actions` given the observations.

        The forward pass runs under `torch.no_grad()` with exploration
        switched off.

        Args:
            actions: The actions to compute the log-likelihoods for.
            obs_batch: Batch of observations.
            state_batches: Optional list of state batches. Each one is
                converted to a torch tensor on `self.device`.
            prev_action_batch: Optional previous actions, added to the model
                inputs when not None.
            prev_reward_batch: Optional previous rewards, added to the model
                inputs when not None.
            actions_normalized: If False and `config["normalize_actions"]`
                is truthy, the actions are passed through `normalize_action`
                before their log-likelihoods are computed.
            **kwargs: Not used by this implementation.

        Returns:
            The result of the action distribution's `logp()` call.

        Raises:
            ValueError: If an `action_sampler_fn` is set, but no
                `action_distribution_fn`.
        """
        if self.action_sampler_fn and self.action_distribution_fn is None:
            raise ValueError(
                "Cannot compute log-prob/likelihood w/o an "
                "`action_distribution_fn` and a provided "
                "`action_sampler_fn`!"
            )

        with torch.no_grad():
            input_dict = self._lazy_tensor_dict(
                {SampleBatch.CUR_OBS: obs_batch, SampleBatch.ACTIONS: actions}
            )
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            # One timestep per batch item.
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            state_batches = [
                convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
            ]

            # Give the exploration component its hook before the forward pass.
            self.exploration.before_compute_actions(explore=False)

            if not self.action_distribution_fn:
                # Default path: plain forward pass through the model.
                dist_class = self.dist_class
                dist_inputs, _ = self.model(input_dict, state_batches, seq_lens)
            else:
                # Custom path: first try the newer signature, which supports
                # `state_batches` and `seq_lens`.
                try:
                    dist_inputs, dist_class, _ = self.action_distribution_fn(
                        self,
                        self.model,
                        input_dict=input_dict,
                        state_batches=state_batches,
                        seq_lens=seq_lens,
                        explore=False,
                        is_training=False,
                    )
                except TypeError as e:
                    signature_mismatch = (
                        "positional argument" in e.args[0]
                        or "unexpected keyword argument" in e.args[0]
                    )
                    # Any other TypeError is a genuine error -> re-raise.
                    if not signature_mismatch:
                        raise e
                    # Fall back to the old signature (backward compatibility).
                    # TODO: Remove in future.
                    dist_inputs, dist_class, _ = self.action_distribution_fn(
                        policy=self,
                        model=self.model,
                        obs_batch=input_dict[SampleBatch.CUR_OBS],
                        explore=False,
                        is_training=False,
                    )

            action_dist = dist_class(dist_inputs, self.model)

            # Read the actions back from the input dict and normalize them,
            # if required.
            logp_actions = input_dict[SampleBatch.ACTIONS]
            if not actions_normalized and self.config["normalize_actions"]:
                logp_actions = normalize_action(logp_actions, self.action_space_struct)

            return action_dist.logp(logp_actions)

    @with_lock
    @override(Policy)
    def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
        """Performs one gradient update on the given batch.

        Args:
            postprocessed_batch: The batch to learn on.

        Returns:
            The fetches returned by `self.compute_gradients`, extended by the
            model's metrics (if a model exists), the custom metrics collected
            by the `on_learn_on_batch` callback, and update/step counters.
        """
        # Switch the model into train mode.
        if self.model:
            self.model.train()

        # Let the callbacks write their custom metrics for this batch.
        custom_metrics = {}
        self.callbacks.on_learn_on_batch(
            policy=self, train_batch=postprocessed_batch, result=custom_metrics
        )

        # Compute the gradients; only the fetches are needed here, as the
        # optimizers are stepped directly afterwards.
        _, fetches = self.compute_gradients(postprocessed_batch)
        self.apply_gradients(_directStepOptimizerSingleton)

        self.num_grad_updates += 1

        if self.model:
            fetches["model"] = self.model.metrics()

        # -1, b/c this diff refers to the state before the update above.
        sampler_grad_updates = postprocessed_batch.num_grad_updates or 0
        diff_vs_sampler = self.num_grad_updates - 1 - sampler_grad_updates

        fetches.update(
            {
                "custom_metrics": custom_metrics,
                NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
                NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: diff_vs_sampler,
            }
        )

        return fetches

    @override(Policy)
    def load_batch_into_buffer(
        self,
        batch: SampleBatch,
        buffer_index: int = 0,
    ) -> int:
        """Loads the given batch into the buffer with index `buffer_index`.

        Args:
            batch: The batch to load. It is flagged as a training batch.
            buffer_index: Index into `self._loaded_batches`. Asserted to be 0
                in the single-CPU case.

        Returns:
            The length of the stored batch (single CPU) or of the first
            per-device slice (all other cases).
        """
        # Mark the batch as a training batch.
        batch.set_training(True)

        # Padding settings shared by all code paths below.
        pad_kwargs = dict(
            max_seq_len=self.max_seq_len,
            shuffle=False,
            batch_divisibility_req=self.batch_divisibility_req,
            view_requirements=self.view_requirements,
        )
        num_devices = len(self.devices)

        # Single CPU: pad the batch and store it as-is in buffer 0.
        if num_devices == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
            pad_batch_to_sequences_of_same_size(batch=batch, **pad_kwargs)
            self._lazy_tensor_dict(batch)
            self._loaded_batches[0] = [batch]
            return len(batch)

        # Example batch (len=28, seq-lens=[4, 7, 4, 10, 3]):
        # 0123 0123456 0123 0123456789ABC

        # Step 1: One sub-batch per device (here: n=2).
        # [0123 0123456] [012] [3 0123456789 ABC]
        # (len=14, 14 seq-lens=[4, 7, 3] [1, 10, 3])
        device_slices = batch.timeslices(num_slices=num_devices)

        # Step 2: Zero-pad each sub-batch (max-seq-len=10).
        # - [0123000000 0123456000 0120000000]
        # - [3000000000 0123456789 ABC0000000]
        for device_slice in device_slices:
            pad_batch_to_sequences_of_same_size(batch=device_slice, **pad_kwargs)

        # Step 3: Move each sub-batch to its device and store all of them
        # under the requested buffer.
        on_device = []
        for idx, device_slice in enumerate(device_slices):
            on_device.append(device_slice.to_device(self.devices[idx]))
        self._loaded_batches[buffer_index] = on_device

        # Report the number of samples loaded per device.
        return len(on_device[0])

    @override(Policy)
    def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
        """Returns the number of samples currently loaded into a buffer.

        Args:
            buffer_index: Index into `self._loaded_batches`. Must be 0 when
                this policy runs on a single CPU device.

        Returns:
            The summed length of all per-device batches stored under
            `buffer_index`.
        """
        # `self.devices` holds `torch.device` objects, which never compare
        # equal to a string such as "/cpu:0". Check the device's `type`
        # instead, as `load_batch_into_buffer` and `learn_on_loaded_batch`
        # do, so the single-CPU sanity check can actually trigger.
        if len(self.devices) == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
        return sum(len(b) for b in self._loaded_batches[buffer_index])

    @override(Policy)
    def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
        """Runs one update on (a slice of) a previously loaded batch.

        Args:
            offset: Start index of the slice to learn on, used whenever the
                loaded data is larger than the per-device batch size.
            buffer_index: Index of the buffer (in `self._loaded_batches`) to
                learn from. Asserted to be 0 in the single-CPU case.

        Returns:
            In the single-CPU case, the result of `self.learn_on_batch`.
            Otherwise a dict with one "tower_{i}" entry per device batch,
            updated with the result of `self.extra_compute_grad_fetches()`.

        Raises:
            ValueError: If nothing has been loaded into the given buffer.
        """
        if not self._loaded_batches[buffer_index]:
            raise ValueError(
                "Must call Policy.load_batch_into_buffer() before "
                "Policy.learn_on_loaded_batch()!"
            )

        num_devices = len(self.devices)

        # Determine the per-device batch size: "minibatch_size" wins, then
        # "sgd_minibatch_size", then "train_batch_size"; the result is split
        # evenly across the devices.
        device_batch_size = self.config.get("minibatch_size")
        if device_batch_size is None:
            device_batch_size = self.config.get(
                "sgd_minibatch_size",
                self.config["train_batch_size"],
            )
        device_batch_size //= num_devices

        # Put all towers into train mode.
        for tower in self.model_gpu_towers or ():
            tower.train()

        # Single CPU: the batch sits in buffer 0 already; simply learn on
        # (a slice of) it.
        if num_devices == 1 and self.devices[0].type == "cpu":
            assert buffer_index == 0
            cpu_batch = self._loaded_batches[0][0]
            if device_batch_size < len(cpu_batch):
                cpu_batch = cpu_batch[offset : offset + device_batch_size]
            return self.learn_on_batch(cpu_batch)

        if num_devices > 1:
            # Sync the weights of the main model (tower-0) to all others.
            main_state_dict = self.model.state_dict()
            # Sanity check: tower-0 must be `self.model` itself.
            assert self.model_gpu_towers[0] is self.model
            for replica in self.model_gpu_towers[1:]:
                replica.load_state_dict(main_state_dict)

        # Use the loaded per-device batches as-is, or slice each of them.
        loaded = self._loaded_batches[buffer_index]
        if device_batch_size >= sum(len(s) for s in loaded):
            device_batches = loaded
        else:
            device_batches = [b[offset : offset + device_batch_size] for b in loaded]

        # Run the callbacks once per device batch and keep their metrics.
        batch_fetches = {}
        for tower_idx, tower_batch in enumerate(device_batches):
            custom_metrics = {}
            self.callbacks.on_learn_on_batch(
                policy=self, train_batch=tower_batch, result=custom_metrics
            )
            batch_fetches[f"tower_{tower_idx}"] = {"custom_metrics": custom_metrics}

        # Compute the gradients on all towers (maybe in parallel).
        tower_outputs = self._multi_gpu_parallel_grad_calc(device_batches)

        # Average each gradient over all towers on `self.device`. Whether a
        # gradient exists at all is decided by looking at the first tower.
        first_tower_grads = tower_outputs[0][0]
        all_grads = [
            None
            if first_grad is None
            else torch.mean(
                torch.stack([out[0][idx].to(self.device) for out in tower_outputs]),
                dim=0,
            )
            for idx, first_grad in enumerate(first_tower_grads)
        ]
        # Write the averaged gradients into the main model's parameters.
        for idx, param in enumerate(self.model.parameters()):
            param.grad = all_grads[idx]

        self.apply_gradients(_directStepOptimizerSingleton)

        self.num_grad_updates += 1

        towers_and_batches = zip(self.model_gpu_towers, device_batches)
        for tower_idx, (tower, tower_batch) in enumerate(towers_and_batches):
            # -1, b/c this diff refers to the state before the update above.
            diff_vs_sampler = (
                self.num_grad_updates - 1 - (tower_batch.num_grad_updates or 0)
            )
            batch_fetches[f"tower_{tower_idx}"].update(
                {
                    LEARNER_STATS_KEY: self.extra_grad_info(tower_batch),
                    "model": tower.metrics(),
                    NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                    DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: diff_vs_sampler,
                }
            )
        batch_fetches.update(self.extra_compute_grad_fetches())

        return batch_fetches

    @with_lock
    @override(Policy)
    def compute_gradients(self, postprocessed_batch: SampleBatch) -> ModelGradients:
        """Computes the gradients for the given batch on the single device.

        Args:
            postprocessed_batch: The batch to compute gradients for. It is
                zero-padded here, if that has not happened yet, and flagged
                as a training batch.

        Returns:
            A tuple of the gradients and a fetches dict, which carries the
            gradient info under `LEARNER_STATS_KEY`.
        """
        # This code path only supports exactly one device.
        assert len(self.devices) == 1

        # Zero-pad the batch, unless that has been done already.
        if not postprocessed_batch.zero_padded:
            pad_batch_to_sequences_of_same_size(
                batch=postprocessed_batch,
                max_seq_len=self.max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )

        postprocessed_batch.set_training(True)
        self._lazy_tensor_dict(postprocessed_batch, device=self.devices[0])

        # Run the gradient calculation; with a single batch passed in, only
        # the first tower output is of interest.
        all_grads, grad_info = self._multi_gpu_parallel_grad_calc(
            [postprocessed_batch]
        )[0]

        grad_info["allreduce_latency"] /= len(self._optimizers)
        grad_info.update(self.extra_grad_info(postprocessed_batch))

        # Copy the extra fetches and store the grad info in that copy.
        result = dict(self.extra_compute_grad_fetches())
        result[LEARNER_STATS_KEY] = grad_info

        return all_grads, result

    @override(Policy)
    def apply_gradients(self, gradients: ModelGradients) -> None:
        """Applies gradients by stepping the optimizer(s).

        Args:
            gradients: Either `_directStepOptimizerSingleton`, in which case
                all optimizers are stepped on the gradients already stored in
                the parameters, or one gradient (torch tensor, numpy array,
                or None) per parameter of `self.model`.
        """
        # Gradients are in place already: just step every optimizer.
        if gradients == _directStepOptimizerSingleton:
            for optimizer in self._optimizers:
                optimizer.step()
            return

        # TODO(sven): Not supported for multiple optimizers yet.
        assert len(self._optimizers) == 1
        for grad, param in zip(gradients, self.model.parameters()):
            # Parameters without a given gradient keep what they have.
            if grad is None:
                continue
            if not torch.is_tensor(grad):
                grad = torch.from_numpy(grad)
            param.grad = grad.to(self.device)

        self._optimizers[0].step()

    def get_tower_stats(self, stats_name: str) -> List[TensorStructType]:
        """Collects one named stat from all towers that provide it.

        Args:
            stats_name: Key to look up in each tower's `tower_stats` dict.
                Towers without this key are skipped.

        Returns:
            One entry per tower that has the stat, with every tensor of that
            entry's structure moved to this Policy's device.

        Raises:
            AssertionError: If none of the towers has `stats_name` in its
                `tower_stats` dict.
        """

        def _to_policy_device(tensor):
            return tensor.to(self.device)

        data = [
            tree.map_structure(_to_policy_device, tower.tower_stats[stats_name])
            for tower in self.model_gpu_towers
            if stats_name in tower.tower_stats
        ]
        assert len(data) > 0, (
            f"Stats `{stats_name}` not found in any of the towers (you have "
            f"{len(self.model_gpu_towers)} towers in total)! Make "
            "sure you call the loss function on at least one of the towers."
        )
        return data

    @override(Policy)
    def get_weights(self) -> ModelWeights:
        """Returns the main model's state dict as detached numpy arrays."""
        numpy_weights = {}
        for name, tensor in self.model.state_dict().items():
            # Move to CPU and detach before converting to numpy.
            numpy_weights[name] = tensor.cpu().detach().numpy()
        return numpy_weights

    @override(Policy)
    def set_weights(self, weights: ModelWeights) -> None:
        """Loads the given weights into the main model.

        Args:
            weights: The weights to set. They are converted to torch tensors
                on `self.device` before being loaded as a state dict.
        """
        torch_weights = convert_to_torch_tensor(weights, device=self.device)
        self.model.load_state_dict(torch_weights)

    @override(Policy)
    def is_recurrent(self) -> bool:
        """Returns the recurrence flag cached in `self._is_recurrent`."""
        return self._is_recurrent

    @override(Policy)
    def num_state_tensors(self) -> int:
        """Returns the number of entries in the model's initial state."""
        initial_state = self.model.get_initial_state()
        return len(initial_state)

    @override(Policy)
    def get_initial_state(self) -> List[TensorType]:
        """Returns the model's initial state entries as numpy arrays."""
        numpy_states = []
        for state in self.model.get_initial_state():
            # Detach and move to CPU before converting to numpy.
            numpy_states.append(state.detach().cpu().numpy())
        return numpy_states

    @override(Policy)
    def get_state(self) -> PolicyState:
        """Returns the policy's state, including optimizer and exploration.

        Returns:
            The state from `super().get_state()`, extended by the key
            "_optimizer_variables" (one numpy-converted state dict per
            optimizer) and, if an exploration object exists, the key
            "_exploration_state".
        """
        state = super().get_state()

        # One numpy-converted state dict per optimizer, in optimizer order.
        state["_optimizer_variables"] = [
            convert_to_numpy(optimizer.state_dict()) for optimizer in self._optimizers
        ]

        # Add exploration state.
        # This is not compatible with RLModules, which have a method
        # `forward_exploration` to specify custom exploration behavior.
        if self.exploration:
            state["_exploration_state"] = self.exploration.get_state()

        return state

    @override(Policy)
    def set_state(self, state: PolicyState) -> None:
        """Restores the policy from the given state.

        Args:
            state: The state to restore from. The keys "_optimizer_variables"
                and "_exploration_state" are optional, "global_timestep" is
                required. The state is also passed on to
                `super().set_state()`.
        """
        # Optimizer variables go first.
        optimizer_vars = state.get("_optimizer_variables")
        if optimizer_vars:
            assert len(optimizer_vars) == len(self._optimizers)
            for optimizer, saved in zip(self._optimizers, optimizer_vars):
                # Torch optimizer param_groups include things like beta, etc.
                # These must stay scalars (not be converted to tensors), or
                # torch.optim.step() will start to complain. Hence, only the
                # "state" part is converted.
                optimizer.load_state_dict(
                    {
                        "param_groups": saved["param_groups"],
                        "state": convert_to_torch_tensor(
                            saved["state"], device=self.device
                        ),
                    }
                )

        # Next, the exploration's state (if both sides have one).
        if hasattr(self, "exploration") and "_exploration_state" in state:
            self.exploration.set_state(state=state["_exploration_state"])

        # Restore the global timestep.
        self.global_timestep = state["global_timestep"]

        # Finally, the Policy's (NN) weights and connectors.
        super().set_state(state)

    def extra_grad_process(
        self, optimizer: "torch.optim.Optimizer", loss: TensorType
    ) -> Dict[str, TensorType]:
        """Hook for processing gradients before `optimizer.step()` is called.

        Called after each optimizer.zero_grad() + loss.backward() call, once
        per self._optimizers/loss-value pair. Override this, e.g., to apply
        gradient clipping. The default implementation does nothing.

        Args:
            optimizer: A torch optimizer object.
            loss: The loss tensor associated with the optimizer.

        Returns:
            A dict with information on the gradient processing step (empty by
            default).
        """
        no_info: Dict[str, TensorType] = {}
        return no_info

    def extra_compute_grad_fetches(self) -> Dict[str, Any]:
        """Returns extra values to fetch and return from compute_gradients().

        The default implementation only provides an empty learner-stats dict
        under `LEARNER_STATS_KEY`. Overrides may add e.g. stats or td-errors.

        Returns:
            Extra fetch dict to be added to the fetch dict of the
            `compute_gradients` call.
        """
        fetches = {LEARNER_STATS_KEY: {}}
        return fetches

    def extra_action_out(
        self,
        input_dict: Dict[str, TensorType],
        state_batches: List[TensorType],
        model: TorchModelV2,
        action_dist: TorchDistributionWrapper,
    ) -> Dict[str, TensorType]:
        """Hook returning extra info to include in the experience batch.

        The default implementation adds nothing.

        Args:
            input_dict: Dict of model input tensors.
            state_batches: List of state tensors.
            model: Reference to the model object.
            action_dist: Torch action dist object to get log-probs (e.g. for
                already sampled actions). `_compute_action_helper()` passes
                None here when an `action_sampler_fn` is used.

        Returns:
            Extra outputs to return in a `compute_actions_from_input_dict()`
            call (3rd return value).
        """
        no_extras: Dict[str, TensorType] = {}
        return no_extras

    def extra_grad_info(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
        """Hook returning extra gradient information for a train batch.

        The default implementation returns no extra info.

        Args:
            train_batch: The training batch for which to produce extra grad
                info.

        Returns:
            The info dict carrying grad info per str key (empty by default).
        """
        no_grad_info: Dict[str, TensorType] = {}
        return no_grad_info

    def optimizer(
        self,
    ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
        """Creates the local PyTorch optimizer(s) to use for this Policy.

        Builds a single Adam optimizer over all parameters of `self.model`,
        using `self.config["lr"]` as learning rate if a `config` attribute
        exists (torch's default learning rate otherwise). If an exploration
        component is set, it is given the chance to extend or replace the
        optimizer list via `get_exploration_optimizer()`.

        Returns:
            The local PyTorch optimizer(s) to use for this Policy.
        """
        model_params = self.model.parameters()
        if hasattr(self, "config"):
            adam = torch.optim.Adam(model_params, lr=self.config["lr"])
        else:
            adam = torch.optim.Adam(model_params)

        optimizers = [adam]
        if not self.exploration:
            return optimizers
        # Let the exploration component add/modify optimizers.
        return self.exploration.get_exploration_optimizer(optimizers)

    @override(Policy)
    def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None:
        """Exports the Policy's Model to a local directory for serving.

        Without `onnx`, the whole model object is written via `torch.save()`
        to `<export_dir>/model.pt`. If saving fails, a possibly created file
        is removed again and a warning is logged (no error is raised).

        With `onnx`, the model is exported via `torch.onnx.export()` to
        `<export_dir>/model.onnx`, using `self._dummy_batch` as example input.

        Args:
            export_dir: Local writable directory (created, if necessary).
            onnx: If given, will export model in ONNX format. The
                value of this parameter sets the ONNX OpSet version to use.
        """
        os.makedirs(export_dir, exist_ok=True)

        if not onnx:
            # Save the torch.Model (architecture and weights, so it can be
            # retrieved w/o access to the original (custom) Model or Policy
            # code).
            model_path = os.path.join(export_dir, "model.pt")
            try:
                torch.save(self.model, f=model_path)
            except Exception:
                # Don't leave a partially written file behind.
                if os.path.exists(model_path):
                    os.remove(model_path)
                logger.warning(ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL)
            return

        # ONNX export, driven by the dummy batch.
        self._lazy_tensor_dict(self._dummy_batch)
        # Provide dummy state inputs if not an RNN (torch cannot jit with
        # returned empty internal states list).
        if "state_in_0" not in self._dummy_batch:
            dummy_value = np.array([1.0])
            self._dummy_batch["state_in_0"] = dummy_value
            self._dummy_batch[SampleBatch.SEQ_LENS] = dummy_value
        seq_lens = self._dummy_batch[SampleBatch.SEQ_LENS]

        # Collect all consecutive `state_in_[i]` entries, starting at i=0.
        state_ins = []
        state_key = "state_in_0"
        while state_key in self._dummy_batch:
            state_ins.append(self._dummy_batch[state_key])
            state_key = f"state_in_{len(state_ins)}"

        dummy_inputs = {
            key: self._dummy_batch[key]
            for key in self._dummy_batch.keys()
            if key != "is_training"
        }
        # Names used both as ONNX input names and as keys of `dynamic_axes`.
        tensor_names = list(dummy_inputs.keys()) + [
            "state_ins",
            SampleBatch.SEQ_LENS,
        ]

        torch.onnx.export(
            self.model,
            (dummy_inputs, state_ins, seq_lens),
            os.path.join(export_dir, "model.onnx"),
            export_params=True,
            opset_version=onnx,
            do_constant_folding=True,
            input_names=tensor_names,
            output_names=["output", "state_outs"],
            dynamic_axes={name: {0: "batch_size"} for name in tensor_names},
        )

    @override(Policy)
    def import_model_from_h5(self, import_file: str) -> None:
        """Imports weights into the torch model.

        Delegates to `self.model.import_from_h5()`.

        Args:
            import_file: The file to import from (passed through to the
                model unchanged).

        Returns:
            Whatever the model's `import_from_h5()` returns.
        """
        model = self.model
        return model.import_from_h5(import_file)

    @with_lock
    def _compute_action_helper(
        self, input_dict, state_batches, seq_lens, explore, timestep
    ):
        """Shared forward pass logic (w/ and w/o trajectory view API).

        Computes actions either via `self.action_sampler_fn` (if set) or by
        building an action distribution (from `self.action_distribution_fn`
        or a plain `self.model` forward pass) and querying `self.exploration`
        for the action to take.

        Side effects: Sets `self._is_recurrent`, switches `self.model` to eval
        mode, writes the computed actions into
        `input_dict[SampleBatch.ACTIONS]`, and advances `self.global_timestep`
        by `len(input_dict[SampleBatch.CUR_OBS])`.

        Args:
            input_dict: The model inputs. Must contain `SampleBatch.CUR_OBS`.
            state_batches: The (RNN) state inputs. None or an empty list marks
                the policy as non-recurrent.
            seq_lens: Sequence lengths, passed on to the model or the
                `action_distribution_fn`.
            explore: Whether to explore. None falls back to
                `self.config["explore"]`.
            timestep: The timestep passed on to the exploration/sampling
                logic. None falls back to `self.global_timestep`.

        Returns:
            A tuple consisting of a) actions, b) state_out, c) extra_fetches
            (all passed through `convert_to_numpy`).

        Raises:
            ValueError: If the resolved `dist_class` is neither a
                `functools.partial` nor a `TorchDistributionWrapper` subclass.
        """
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep
        # Non-empty state inputs mark this policy as recurrent.
        self._is_recurrent = state_batches is not None and state_batches != []

        # Switch to eval mode.
        if self.model:
            self.model.eval()

        if self.action_sampler_fn:
            # Custom sampler: No action distribution object is available.
            action_dist = dist_inputs = None
            action_sampler_outputs = self.action_sampler_fn(
                self,
                self.model,
                input_dict,
                state_batches,
                explore=explore,
                timestep=timestep,
            )
            # The sampler may return dist_inputs as an (optional) 3rd item.
            if len(action_sampler_outputs) == 4:
                actions, logp, dist_inputs, state_out = action_sampler_outputs
            else:
                actions, logp, state_out = action_sampler_outputs
        else:
            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(explore=explore, timestep=timestep)
            if self.action_distribution_fn:
                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                try:
                    dist_inputs, dist_class, state_out = self.action_distribution_fn(
                        self,
                        self.model,
                        input_dict=input_dict,
                        state_batches=state_batches,
                        seq_lens=seq_lens,
                        explore=explore,
                        timestep=timestep,
                        is_training=False,
                    )
                # Trying the old way (to stay backward compatible).
                # TODO: Remove in future.
                except TypeError as e:
                    # Only signature mismatches trigger the fallback; any other
                    # TypeError is re-raised below.
                    if (
                        "positional argument" in e.args[0]
                        or "unexpected keyword argument" in e.args[0]
                    ):
                        (
                            dist_inputs,
                            dist_class,
                            state_out,
                        ) = self.action_distribution_fn(
                            self,
                            self.model,
                            input_dict[SampleBatch.CUR_OBS],
                            explore=explore,
                            timestep=timestep,
                            is_training=False,
                        )
                    else:
                        raise e
            else:
                # Default: Plain model forward pass with the Policy's dist class.
                dist_class = self.dist_class
                dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)

            if not (
                isinstance(dist_class, functools.partial)
                or issubclass(dist_class, TorchDistributionWrapper)
            ):
                raise ValueError(
                    "`dist_class` ({}) not a TorchDistributionWrapper "
                    "subclass! Make sure your `action_distribution_fn` or "
                    "`make_model_and_action_dist` return a correct "
                    "distribution class.".format(dist_class.__name__)
                )
            action_dist = dist_class(dist_inputs, self.model)

            # Get the exploration action from the forward results.
            actions, logp = self.exploration.get_exploration_action(
                action_distribution=action_dist, timestep=timestep, explore=explore
            )

        # Make the computed actions available to `extra_action_out()` (this
        # mutates the passed-in `input_dict`).
        input_dict[SampleBatch.ACTIONS] = actions

        # Add default and custom fetches.
        extra_fetches = self.extra_action_out(
            input_dict, state_batches, self.model, action_dist
        )

        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        # Action-logp and action-prob.
        if logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp.float())
            extra_fetches[SampleBatch.ACTION_LOGP] = logp

        # Update our global timestep by the batch size.
        self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

        return convert_to_numpy((actions, state_out, extra_fetches))

    def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None):
        """Installs a torch-converting get-interceptor on the given batch.

        After this call, values read from the returned batch are passed
        through `convert_to_torch_tensor` with the chosen device.

        Args:
            postprocessed_batch: The batch to install the interceptor on.
                Anything that is not a SampleBatch is wrapped into a new
                SampleBatch first.
            device: The device to convert to. If falsy, `self.device` is used.

        Returns:
            The SampleBatch carrying the interceptor (the input object itself,
            if it already was a SampleBatch).
        """
        # TODO: (sven): Keep for a while to ensure backward compatibility.
        batch = (
            postprocessed_batch
            if isinstance(postprocessed_batch, SampleBatch)
            else SampleBatch(postprocessed_batch)
        )
        to_tensor = functools.partial(
            convert_to_torch_tensor, device=device or self.device
        )
        batch.set_get_interceptor(to_tensor)
        return batch

    def _multi_gpu_parallel_grad_calc(
        self, sample_batches: List[SampleBatch]
    ) -> List[Tuple[List[TensorType], GradInfoDict]]:
        """Performs a parallelized loss and gradient calculation over the batch.

        Takes the already sharded train batch (one shard per entry in
        `self.model_gpu_towers`) and passes each data shard through the loss
        function using the individual devices' models
        (self.model_gpu_towers). Then returns each tower's outputs.

        Shards are processed serially if there is only one device or
        `config["_fake_gpus"]` is set, and in one thread per shard otherwise.

        Args:
            sample_batches: A list of SampleBatch shards to
                calculate loss and gradients for. Must have the same length
                as `self.model_gpu_towers`.

        Returns:
            A list (one item per device) of 2-tuples, each with 1) gradient
            list and 2) grad info dict. The gradient list has one entry per
            model parameter; entries remain None for parameters that are not
            part of any optimizer's param group.

        Raises:
            ValueError: If the loss/gradient calculation failed on any shard.
                The error message contains the original traceback and the
                original exception is chained as its cause.
        """
        assert len(self.model_gpu_towers) == len(sample_batches)
        # Guards `results`, which the workers write to (possibly from
        # different threads).
        lock = threading.Lock()
        # Maps shard index to either `(all_grads, grad_info)` on success or
        # `(ValueError, original_exception)` on failure.
        results = {}
        # Remember the caller's grad mode; it is re-applied inside each worker.
        grad_enabled = torch.is_grad_enabled()

        def _worker(shard_idx, model, sample_batch, device):
            # Computes loss(es) and gradients for one shard on one tower and
            # stores the outcome (or the error) under `results[shard_idx]`.
            torch.set_grad_enabled(grad_enabled)
            try:
                with NullContextManager() if device.type == "cpu" else torch.cuda.device(  # noqa: E501
                    device
                ):
                    loss_out = force_list(
                        self._loss(self, model, self.dist_class, sample_batch)
                    )

                    # Call Model's custom-loss with Policy loss outputs and
                    # train_batch.
                    loss_out = model.custom_loss(loss_out, sample_batch)

                    # One loss term is expected per optimizer.
                    assert len(loss_out) == len(self._optimizers)

                    # Loop through all optimizers.
                    grad_info = {"allreduce_latency": 0.0}

                    parameters = list(model.parameters())
                    # One slot per model parameter (filled in below).
                    all_grads = [None for _ in range(len(parameters))]
                    for opt_idx, opt in enumerate(self._optimizers):
                        # Erase gradients in all vars of the tower that this
                        # optimizer would affect.
                        param_indices = self.multi_gpu_param_groups[opt_idx]
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices and param.grad is not None:
                                param.grad.data.zero_()
                        # Recompute gradients of loss over all variables.
                        # `retain_graph=True` keeps the graph for the backward
                        # calls of the remaining loss terms.
                        loss_out[opt_idx].backward(retain_graph=True)
                        grad_info.update(
                            self.extra_grad_process(opt, loss_out[opt_idx])
                        )

                        grads = []
                        # Note that return values are just references;
                        # Calling zero_grad would modify the values.
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices:
                                if param.grad is not None:
                                    grads.append(param.grad)
                                all_grads[param_idx] = param.grad

                        if self.distributed_world_size:
                            start = time.time()
                            if torch.cuda.is_available():
                                # Sadly, allreduce_coalesced does not work with
                                # CUDA yet.
                                for g in grads:
                                    torch.distributed.all_reduce(
                                        g, op=torch.distributed.ReduceOp.SUM
                                    )
                            else:
                                torch.distributed.all_reduce_coalesced(
                                    grads, op=torch.distributed.ReduceOp.SUM
                                )

                            # Turn the summed gradients into averages (in
                            # place).
                            for param_group in opt.param_groups:
                                for p in param_group["params"]:
                                    if p.grad is not None:
                                        p.grad /= self.distributed_world_size

                            grad_info["allreduce_latency"] += time.time() - start

                with lock:
                    results[shard_idx] = (all_grads, grad_info)
            except Exception as e:
                import traceback

                # Don't raise from within the worker: Store the error, so the
                # caller can re-raise it.
                with lock:
                    results[shard_idx] = (
                        ValueError(
                            f"Error In tower {shard_idx} on device "
                            f"{device} during multi GPU parallel gradient "
                            f"calculation:"
                            f": {e}\n"
                            f"Traceback: \n"
                            f"{traceback.format_exc()}\n"
                        ),
                        e,
                    )

        # Single device (GPU) or fake-GPU case (serialize for better
        # debugging).
        if len(self.devices) == 1 or self.config["_fake_gpus"]:
            for shard_idx, (model, sample_batch, device) in enumerate(
                zip(self.model_gpu_towers, sample_batches, self.devices)
            ):
                _worker(shard_idx, model, sample_batch, device)
                # Raise errors right away for better debugging.
                # (Shards run in index order here, so the highest key in
                # `results` belongs to the shard that just finished.)
                last_result = results[len(results) - 1]
                if isinstance(last_result[0], ValueError):
                    raise last_result[0] from last_result[1]
        # Multi device (GPU) case: Parallelize via threads.
        else:
            threads = [
                threading.Thread(
                    target=_worker, args=(shard_idx, model, sample_batch, device)
                )
                for shard_idx, (model, sample_batch, device) in enumerate(
                    zip(self.model_gpu_towers, sample_batches, self.devices)
                )
            ]

            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        # Gather all threads' outputs and return.
        outputs = []
        for shard_idx in range(len(sample_batches)):
            output = results[shard_idx]
            # Re-raise the first stored worker error (in shard order).
            if isinstance(output[0], Exception):
                raise output[0] from output[1]
            outputs.append(results[shard_idx])
        return outputs


@OldAPIStack
class DirectStepOptimizer:
    """Typesafe method for indicating `apply_gradients` can directly step the
    optimizers with in-place gradients.

    The class is a singleton: Every call to the constructor returns the same
    instance. Instances compare equal to any object of exactly the same type
    and unequal to everything else. They are hashable (consistently with
    `__eq__`), so the sentinel can be used in sets and as a dict key.
    """

    # The one shared instance (created lazily by the first constructor call).
    _instance = None

    def __new__(cls):
        """Returns the shared instance, creating it on first use."""
        if DirectStepOptimizer._instance is None:
            DirectStepOptimizer._instance = super().__new__(cls)
        return DirectStepOptimizer._instance

    def __eq__(self, other):
        """Returns True iff `other` is of exactly the same type as `self`."""
        return type(self) is type(other)

    def __hash__(self):
        """Returns a hash that is consistent with `__eq__`.

        Defining `__eq__` alone would implicitly set `__hash__` to None and
        thereby make instances unhashable. Equality only depends on the exact
        type, so the hash is derived from the type as well.
        """
        return hash(type(self))

    def __repr__(self):
        return "DirectStepOptimizer"


# Module-level instance of the `DirectStepOptimizer` marker class. As the
# class' `__new__` always hands out one shared instance, any later
# `DirectStepOptimizer()` call returns this very same object.
_directStepOptimizerSingleton = DirectStepOptimizer()
