"""Eager mode TF policy built using build_tf_policy().

It supports both traced and non-traced eager execution modes.
"""

import logging
import os
import threading
from typing import Dict, List, Optional, Tuple, Type, Union

import gymnasium as gym
import tree  # pip install dm_tree

from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.policy.eager_tf_policy import (
    _convert_to_tf,
    _disallow_var_creation,
    _OptimizerWrapper,
    _traced_eager_policy,
)
from ray.rllib.policy.policy import Policy, PolicyState
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import (
    OldAPIStack,
    OverrideToImplementCustomLogic,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
    is_overridden,
    override,
)
from ray.rllib.utils.error import ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics import (
    DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
    NUM_AGENT_STEPS_TRAINED,
    NUM_GRAD_UPDATES_LIFETIME,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.tf_utils import get_gpu_devices
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.typing import (
    AlgorithmConfigDict,
    LocalOptimizer,
    ModelGradients,
    TensorType,
)
from ray.util.debug import log_once

tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)


@OldAPIStack
class EagerTFPolicyV2(Policy):
    """A TF-eager / TF2 based tensorflow policy.

    This class is intended to be used and extended by sub-classing.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
        **kwargs,
    ):
        """Initializes an EagerTFPolicyV2 instance.

        Creates the tf variables for the global timestep and the explore
        flag, then builds (in this order) the action distribution class, the
        model, the view requirements and the exploration object.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: The Policy's config dict. Keys read in this constructor:
                "framework" (defaults to "tf2"), "explore" and
                "model" -> "max_seq_len".
            **kwargs: Accepted, but not referenced inside this constructor.
        """
        self.framework = config.get("framework", "tf2")

        # Log device.
        logger.info(
            "Creating TF-eager policy running on {}.".format(
                "GPU" if get_gpu_devices() else "CPU"
            )
        )

        # Base class setup. `self.config` is read below, so it is presumably
        # populated by this call -- TODO confirm against `Policy.__init__`.
        Policy.__init__(self, observation_space, action_space, config)

        self._is_training = False
        # Global timestep should be a tensor.
        self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64)
        self.explore = tf.Variable(
            self.config["explore"], trainable=False, dtype=tf.bool
        )

        # Log device and worker index.
        num_gpus = self._get_num_gpus_for_policy()
        if num_gpus > 0:
            gpu_ids = get_gpu_devices()
            logger.info(f"Found {len(gpu_ids)} visible cuda devices.")

        # NOTE(review): `_is_training` was already set to False above; this
        # second assignment looks redundant.
        self._is_training = False

        self._loss_initialized = False
        # Backward compatibility workaround so Policy will call self.loss() directly.
        # TODO(jungong): clean up after all policies are migrated to new sub-class
        #  implementation.
        self._loss = None

        self.batch_divisibility_req = self.get_batch_divisibility_req()
        self._max_seq_len = self.config["model"]["max_seq_len"]

        # The return value of this hook is not used here.
        self.validate_spaces(observation_space, action_space, self.config)

        # If using default make_model(), dist_class will get updated when
        # the model is created next.
        self.dist_class = self._init_dist_class()
        self.model = self.make_model()

        # Needs `self.model` (its view requirements are merged into ours).
        self._init_view_requirements()

        self.exploration = self._create_exploration()
        # A non-empty initial state marks the policy as recurrent.
        self._state_inputs = self.model.get_initial_state()
        self._is_recurrent = len(self._state_inputs) > 0

        # Got to reset global_timestep again after fake run-throughs.
        self.global_timestep.assign(0)

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when calling the model
        # first, then its value function (e.g. in a loss function), in
        # between of which another model call is made (e.g. to compute an
        # action).
        self._lock = threading.RLock()

        # Only for `config.eager_tracing=True`: A counter to keep track of
        # how many times an eager-traced method (e.g.
        # `self._compute_actions_helper`) has been re-traced by tensorflow.
        # We will raise an error if more than n re-tracings have been
        # detected, since this would considerably slow down execution.
        # The variable below should only get incremented during the
        # tf.function trace operations, never when calling the already
        # traced function after that.
        self._re_trace_counter = 0

    @staticmethod
    def enable_eager_execution_if_necessary():
        """Switches tf into eager mode, unless it already executes eagerly.

        Does nothing if the module-level `tf1` handle is falsy.
        """
        # This class may run as a @ray.remote actor, in which case eager
        # mode has not necessarily been activated yet.
        if not tf1:
            return
        if tf1.executing_eagerly():
            return
        tf1.enable_eager_execution()

    @OverrideToImplementCustomLogic
    def validate_spaces(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
    ):
        """Hook for validating the given spaces and config.

        The default implementation performs no checks. It is called from
        `__init__`, which ignores the return value.

        Args:
            obs_space: The observation space to validate.
            action_space: The action space to validate.
            config: The Policy's config dict.

        Returns:
            An empty dict (default implementation).
        """
        return {}

    @OverrideToImplementCustomLogic
    @override(Policy)
    def loss(
        self,
        model: Union[ModelV2, "tf.keras.Model"],
        dist_class: Type[TFActionDistribution],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        """Compute loss for this policy using model, dist_class and a train_batch.

        Must be overridden by sub-classes; this base implementation always
        raises. `_compute_gradients_helper` calls it as
        `self.loss(self.model, self.dist_class, samples)` inside a
        `tf.GradientTape`.

        Args:
            model: The Model to calculate the loss for.
            dist_class: The action distr. class.
            train_batch: The training data.

        Returns:
            A single loss tensor or a list of loss tensors.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    @OverrideToImplementCustomLogic
    def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
        """Stats function. Returns a dict of statistics.

        Override to report custom stats; the default reports none.

        Args:
            train_batch: The SampleBatch (already) used for training.

        Returns:
            The stats dict (empty by default).
        """
        return {}

    @OverrideToImplementCustomLogic
    def grad_stats_fn(
        self, train_batch: SampleBatch, grads: ModelGradients
    ) -> Dict[str, TensorType]:
        """Gradient stats function. Returns a dict of statistics.

        Override to report custom gradient stats; the default reports none.

        Args:
            train_batch: The SampleBatch (already) used for training.
            grads: The gradients to derive the statistics from.

        Returns:
            The stats dict (empty by default).
        """
        return {}

    @OverrideToImplementCustomLogic
    def make_model(self) -> ModelV2:
        """Builds the default ModelV2 for this Policy via the ModelCatalog.

        Returns:
            The Model for the Policy to use.
        """
        model_config = self.config["model"]
        # Only the size of the action distribution inputs is needed here; the
        # distribution class itself is discarded.
        _unused_dist_class, num_outputs = ModelCatalog.get_action_dist(
            self.action_space, model_config
        )
        default_model = ModelCatalog.get_model_v2(
            self.observation_space,
            self.action_space,
            num_outputs,
            model_config,
            framework=self.framework,
        )
        return default_model

    @OverrideToImplementCustomLogic
    def compute_gradients_fn(
        self, policy: Policy, optimizer: LocalOptimizer, loss: TensorType
    ) -> ModelGradients:
        """Gradients computing function (from loss tensor, using local optimizer).

        Only used when overridden; the default returns None.

        NOTE(review): `_compute_gradients_helper` calls this hook with only
        `(optimizer, loss)`, or with lists of both when
        `_tf_policy_handles_more_than_one_loss` is set. It does not pass a
        `policy` argument -- confirm the intended signature before overriding.

        Args:
            policy: The Policy object that generated the loss tensor and
                that holds the given local optimizer.
            optimizer: The tf (local) optimizer object to
                calculate the gradients with.
            loss: The loss tensor for which gradients should be
                calculated.

        Returns:
            ModelGradients: List of the possibly clipped gradients- and variable
                tuples. The default implementation returns None.
        """
        return None

    @OverrideToImplementCustomLogic
    def apply_gradients_fn(
        self,
        optimizer: "tf.keras.optimizers.Optimizer",
        grads: ModelGradients,
    ) -> "tf.Operation":
        """Gradients applying function (using the given local optimizer).

        Override to customize how gradients are applied; the default returns
        None.

        Args:
            optimizer: The tf (local) optimizer object to
                apply the gradients with.
            grads: The gradient tensor to be applied.

        Returns:
            "tf.Operation": TF operation that applies supplied gradients.
                The default implementation returns None.
        """
        return None

    @OverrideToImplementCustomLogic
    def action_sampler_fn(
        self,
        model: ModelV2,
        *,
        obs_batch: TensorType,
        state_batches: TensorType,
        **kwargs,
    ) -> Tuple[TensorType, TensorType, TensorType, List[TensorType]]:
        """Custom function for sampling new actions given policy.

        Only used when overridden. `_compute_actions_helper` then calls it
        instead of the default model forward pass and exploration step.

        Args:
            model: Underlying model.
            obs_batch: Observation tensor batch.
            state_batches: Action sampling state batch.
            **kwargs: Further keyword arguments. `_compute_actions_helper`
                passes `seq_lens`, `explore`, `timestep` and `episodes`.

        Returns:
            Sampled action
            Log-likelihood
            Action distribution inputs
            Updated state
            (The default implementation returns None for all four.)
        """
        return None, None, None, None

    @OverrideToImplementCustomLogic
    def action_distribution_fn(
        self,
        model: ModelV2,
        *,
        obs_batch: TensorType,
        state_batches: TensorType,
        **kwargs,
    ) -> Tuple[TensorType, type, List[TensorType]]:
        """Action distribution function for this Policy.

        Only used when overridden. The returned distribution class is then
        stored in `self.dist_class` by the callers.

        Args:
            model: Underlying model.
            obs_batch: Observation tensor batch.
            state_batches: Action sampling state batch.
            **kwargs: Further keyword arguments. `_compute_actions_helper`
                passes `seq_lens`, `explore`, `timestep` and `is_training`.

        Returns:
            Distribution input.
            ActionDistribution class.
            State outs.
            (The default implementation returns None for all three.)
        """
        return None, None, None

    @OverrideToImplementCustomLogic
    def get_batch_divisibility_req(self) -> int:
        """Get batch divisibility request.

        The result is stored as `self.batch_divisibility_req` in `__init__`
        and handed to `pad_batch_to_sequences_of_same_size` before training.

        Returns:
            Size N. A sample batch must be of size K*N.
        """
        # By default, any sized batch is ok, so simply return 1.
        return 1

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def extra_action_out_fn(self) -> Dict[str, TensorType]:
        """Extra values to fetch and return from compute_actions().

        `_compute_actions_helper` merges the returned dict into the extra
        fetches after the action-logp, action-prob and dist-input entries,
        so equal keys returned here take precedence.

        Returns:
             Dict[str, TensorType]: An extra fetch-dict to be passed to and
                returned from the compute_actions() call (empty by default).
        """
        return {}

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def extra_learn_fetches_fn(self) -> Dict[str, TensorType]:
        """Extra stats to be reported after gradient computation.

        Override to report additional values; the default reports none.

        Returns:
             Dict[str, TensorType]: An extra fetch-dict (empty by default).
        """
        return {}

    @override(Policy)
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[SampleBatch] = None,
        episode=None,
    ):
        """Post process trajectory in the format of a SampleBatch.

        Args:
            sample_batch: Batch of experiences for the policy,
                which will contain at most one episode trajectory.
            other_agent_batches: In a multi-agent env, this contains a
                mapping of agent ids to (policy, agent_batch) tuples
                containing the policy and experiences of the other agents.
            episode: An optional multi-agent episode object to provide
                access to all of the internal episode state, which may
                be useful for model-based or multi-agent algorithms.

        Returns:
            The postprocessed sample batch.
        """
        # This policy only works in eager mode.
        assert tf.executing_eagerly()
        # NOTE(review): only `sample_batch` is forwarded to the base class;
        # `other_agent_batches` and `episode` are not passed on here.
        return Policy.postprocess_trajectory(self, sample_batch)

    @OverrideToImplementCustomLogic
    def optimizer(
        self,
    ) -> Union["tf.keras.optimizers.Optimizer", List["tf.keras.optimizers.Optimizer"]]:
        """Creates the TF optimizer(s) used to optimize this Policy's Model.

        The default is a single Adam optimizer using `config["lr"]`.

        Returns:
            A local optimizer or a list of local optimizers to use for this
                Policy's Model.
        """
        learning_rate = self.config["lr"]
        return tf.keras.optimizers.Adam(learning_rate)

    def _init_dist_class(self):
        """Determines the initial action distribution class.

        Returns:
            The ModelCatalog's distribution class for the action space, or
            None if a custom `action_sampler_fn` / `action_distribution_fn`
            is in use.

        Raises:
            ValueError: If one of the custom action functions is overridden,
                but `make_model` is not.
        """
        uses_custom_action_fn = is_overridden(
            self.action_sampler_fn
        ) or is_overridden(self.action_distribution_fn)

        # Default case: Let the catalog pick the distribution class.
        if not uses_custom_action_fn:
            catalog_dist_class, _ = ModelCatalog.get_action_dist(
                self.action_space, self.config["model"]
            )
            return catalog_dist_class

        # Custom action functions: no dist class is determined here.
        if is_overridden(self.make_model):
            return None

        raise ValueError(
            "`make_model` is required if `action_sampler_fn` OR "
            "`action_distribution_fn` is given"
        )

    def _init_view_requirements(self):
        """Merges the model's view requirements into the Policy's own ones."""
        # If recurrent, auto-update the model's inference view requirements.
        self._update_model_view_requirements_from_init_state()

        # Model requirements are added to (and may override) the Policy's.
        policy_view_reqs = self.view_requirements
        policy_view_reqs.update(self.model.view_requirements)

        # Env-infos are never used as training inputs.
        if SampleBatch.INFOS not in policy_view_reqs:
            return
        policy_view_reqs[SampleBatch.INFOS].used_for_training = False

    def maybe_initialize_optimizer_and_loss(self):
        """Creates the local optimizer(s) and initializes the loss.

        Sets `self._optimizers`, `self._optimizer` and, once the loss has
        been initialized from a dummy batch, `self._loss_initialized`.
        """
        local_optimizers = force_list(self.optimizer())

        # Policies with RLModules don't have an exploration object.
        exploration = self.exploration
        if exploration:
            local_optimizers = exploration.get_exploration_optimizer(
                local_optimizers
            )

        # One local (tf) optimizer per loss term.
        self._optimizers: List[LocalOptimizer] = local_optimizers
        # Backward compatibility: A user's policy may only support a single
        # loss term and optimizer (no lists).
        if local_optimizers:
            self._optimizer: LocalOptimizer = local_optimizers[0]
        else:
            self._optimizer: LocalOptimizer = None

        self._initialize_loss_from_dummy_batch(auto_remove_unneeded_view_reqs=True)
        self._loss_initialized = True

    @override(Policy)
    def compute_actions_from_input_dict(
        self,
        input_dict: Dict[str, TensorType],
        explore: bool = None,
        timestep: Optional[int] = None,
        episodes=None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions for the given input dict.

        Args:
            input_dict: The inputs for the forward pass. Keys starting with
                "state_in" are collected as the internal state inputs.
            explore: Whether to explore. Falls back to `self.explore`.
            timestep: The current timestep. Falls back to
                `self.global_timestep`.
            episodes: Optional episode objects. Not passed on if
                `config["eager_tracing"]` is set.
            **kwargs: Not used by this implementation.

        Returns:
            The numpy-converted tuple of actions, state outputs and extra
            fetches, as returned by `_compute_actions_helper`.
        """
        self._is_training = False

        if explore is None:
            explore = self.explore
        if timestep is None:
            timestep = self.global_timestep
        if isinstance(timestep, tf.Tensor):
            timestep = int(timestep.numpy())

        # Pass lazy (eager) tensor dict to Model as `input_dict`.
        input_dict = self._lazy_tensor_dict(input_dict)
        input_dict.set_training(False)

        # Pack internal state inputs into (separate) list.
        state_batches = [
            input_dict[key] for key in input_dict.keys() if "state_in" in key[:8]
        ]
        self._state_in = state_batches
        num_state_tensors = len(tree.flatten(self._state_in))
        self._is_recurrent = num_state_tensors > 0

        # Call the exploration before_compute_actions hook.
        # Policies with RLModules don't have an exploration object.
        if self.exploration:
            self.exploration.before_compute_actions(
                timestep=timestep, explore=explore, tf_sess=self.get_session()
            )

        # TODO: Passing episodes into a traced method does not work.
        episodes_arg = None if self.config["eager_tracing"] else episodes
        ret = self._compute_actions_helper(
            input_dict,
            state_batches,
            episodes_arg,
            explore,
            timestep,
        )

        # Update our global timestep by the batch size.
        first_action_component = tree.flatten(ret[0])[0]
        batch_size = first_action_component.shape.as_list()[0]
        self.global_timestep.assign_add(batch_size)
        return convert_to_numpy(ret)

    # TODO(jungong) : deprecate this API and make compute_actions_from_input_dict the
    # only canonical entry point for inference.
    @override(Policy)
    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        info_batch=None,
        episodes=None,
        explore=None,
        timestep=None,
        **kwargs,
    ):
        """Computes actions by delegating to compute_actions_from_input_dict().

        Args:
            obs_batch: Batch of observations.
            state_batches: Optional list of internal state input batches.
                The i-th entry is stored under the key "state_in_{i}".
            prev_action_batch: Optional batch of previous actions.
            prev_reward_batch: Optional batch of previous rewards.
            info_batch: Optional batch of env infos.
            episodes: Optional episode objects (passed through).
            explore: Whether to explore (passed through).
            timestep: The current timestep (passed through).
            **kwargs: Passed through to `compute_actions_from_input_dict`.

        Returns:
            The result of `compute_actions_from_input_dict`.
        """
        # Create input dict to simply pass the entire call to
        # self.compute_actions_from_input_dict().
        input_dict = SampleBatch(
            {
                SampleBatch.CUR_OBS: obs_batch,
            },
            _is_training=tf.constant(False),
        )
        if state_batches is not None:
            # One key per state tensor ("state_in_0", "state_in_1", ...), so
            # that the "state_in" keys can be picked up again downstream.
            for i, s in enumerate(state_batches):
                input_dict[f"state_in_{i}"] = s
        if prev_action_batch is not None:
            input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
        if prev_reward_batch is not None:
            input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
        if info_batch is not None:
            input_dict[SampleBatch.INFOS] = info_batch

        return self.compute_actions_from_input_dict(
            input_dict=input_dict,
            explore=explore,
            timestep=timestep,
            episodes=episodes,
            **kwargs,
        )

    @with_lock
    @override(Policy)
    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorType], TensorType],
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
        prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
        actions_normalized: bool = True,
        in_training: bool = True,
    ) -> TensorType:
        """Computes the log-likelihoods of the given actions.

        Args:
            actions: Batch of actions to compute the log-likelihoods for.
            obs_batch: Batch of observations.
            state_batches: Optional list of internal state input batches.
            prev_action_batch: Optional batch of previous actions.
            prev_reward_batch: Optional batch of previous rewards.
            actions_normalized: Whether `actions` are already normalized. If
                False and `config["normalize_actions"]` is set, the actions
                are normalized before computing their log-likelihoods.
            in_training: Not used by this implementation.

        Returns:
            The result of the action distribution's `logp(actions)`.

        Raises:
            ValueError: If `action_sampler_fn` is overridden, but
                `action_distribution_fn` is not.
        """
        has_custom_dist_fn = is_overridden(self.action_distribution_fn)
        if is_overridden(self.action_sampler_fn) and not has_custom_dist_fn:
            raise ValueError(
                "Cannot compute log-prob/likelihood w/o an "
                "`action_distribution_fn` and a provided "
                "`action_sampler_fn`!"
            )

        # Each row is treated as a sequence of length 1.
        seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
        input_batch = SampleBatch(
            {
                SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch),
                SampleBatch.ACTIONS: actions,
            },
            _is_training=False,
        )
        if prev_action_batch is not None:
            input_batch[SampleBatch.PREV_ACTIONS] = tf.convert_to_tensor(
                prev_action_batch
            )
        if prev_reward_batch is not None:
            input_batch[SampleBatch.PREV_REWARDS] = tf.convert_to_tensor(
                prev_reward_batch
            )

        # Exploration hook before each forward pass.
        if self.exploration:
            # Policies with RLModules don't have an exploration object.
            self.exploration.before_compute_actions(explore=False)

        # Action dist class and inputs are generated via custom function.
        if has_custom_dist_fn:
            # Call the hook the same way `_compute_actions_helper` does:
            # model as the only positional arg, everything else by keyword
            # (`obs_batch` and `state_batches` are keyword-only in the hook).
            dist_inputs, self.dist_class, _ = self.action_distribution_fn(
                self.model,
                obs_batch=input_batch[SampleBatch.CUR_OBS],
                state_batches=state_batches,
                seq_lens=seq_lens,
                explore=False,
                is_training=False,
            )
        # Default log-likelihood calculation.
        else:
            dist_inputs, _ = self.model(input_batch, state_batches, seq_lens)

        action_dist = self.dist_class(dist_inputs, self.model)

        # Normalize actions if necessary.
        if not actions_normalized and self.config["normalize_actions"]:
            actions = normalize_action(actions, self.action_space_struct)

        return action_dist.logp(actions)

    @with_lock
    @override(Policy)
    def learn_on_batch(self, postprocessed_batch):
        """Performs one gradient update on the given batch.

        Args:
            postprocessed_batch: The batch to learn on. It is padded in
                place before being used for the update.

        Returns:
            The numpy-converted stats dict of the update, extended by custom
            metrics and step/grad-update counters.
        """
        # Callback handling: callbacks may write into `custom_metrics`.
        custom_metrics = {}
        self.callbacks.on_learn_on_batch(
            policy=self, train_batch=postprocessed_batch, result=custom_metrics
        )

        pad_batch_to_sequences_of_same_size(
            postprocessed_batch,
            shuffle=False,
            max_seq_len=self._max_seq_len,
            batch_divisibility_req=self.batch_divisibility_req,
            view_requirements=self.view_requirements,
        )

        self._is_training = True
        train_batch = self._lazy_tensor_dict(postprocessed_batch)
        train_batch.set_training(True)
        stats = self._learn_on_batch_helper(train_batch)
        self.num_grad_updates += 1

        extra_stats = {
            "custom_metrics": custom_metrics,
            NUM_AGENT_STEPS_TRAINED: train_batch.count,
            NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
            # -1, b/c we have to measure this diff before we do the update above.
            DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                self.num_grad_updates - 1 - (train_batch.num_grad_updates or 0)
            ),
        }
        stats.update(extra_stats)

        return convert_to_numpy(stats)

    @override(Policy)
    def compute_gradients(
        self, postprocessed_batch: SampleBatch
    ) -> Tuple[ModelGradients, Dict[str, TensorType]]:
        """Computes gradients on the given batch without applying them.

        Args:
            postprocessed_batch: The batch to compute gradients for. It is
                padded in place before being used.

        Returns:
            The numpy-converted tuple of (gradients, stats dict).
        """
        pad_batch_to_sequences_of_same_size(
            postprocessed_batch,
            shuffle=False,
            max_seq_len=self._max_seq_len,
            batch_divisibility_req=self.batch_divisibility_req,
            view_requirements=self.view_requirements,
        )

        self._is_training = True
        # Use the batch returned by `_lazy_tensor_dict` (as `learn_on_batch`
        # does) rather than relying on the input being altered in place.
        postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch)
        postprocessed_batch.set_training(True)
        # The (grad, var) tuples are not needed here; only grads and stats.
        _, grads, stats = self._compute_gradients_helper(postprocessed_batch)
        return convert_to_numpy((grads, stats))

    @override(Policy)
    def apply_gradients(self, gradients: ModelGradients) -> None:
        """Applies the given gradients to the model's trainable variables.

        Args:
            gradients: The gradients, one per trainable variable and in the
                same order. None entries are passed on as None.
        """
        # `trainable_variables` is a property on keras Models, but a method
        # on other models (same distinction as in `variables()`).
        if isinstance(self.model, tf.keras.Model):
            trainable_vars = self.model.trainable_variables
        else:
            trainable_vars = self.model.trainable_variables()

        grad_tensors = [
            (tf.convert_to_tensor(g) if g is not None else None) for g in gradients
        ]
        self._apply_gradients_helper(list(zip(grad_tensors, trainable_vars)))

    @override(Policy)
    def get_weights(self, as_dict=False):
        """Returns the current values of all of this policy's variables.

        Args:
            as_dict: If True, return a dict mapping variable names to values.
                Otherwise, return a list of values.

        Returns:
            The variables' numpy values as list or dict.
        """
        policy_vars = self.variables()
        if not as_dict:
            return [var.numpy() for var in policy_vars]
        return {var.name: var.numpy() for var in policy_vars}

    @override(Policy)
    def set_weights(self, weights):
        """Assigns the given values to this policy's variables, in order.

        Args:
            weights: The new values; must have as many entries as
                `self.variables()`.
        """
        policy_vars = self.variables()
        num_given, num_expected = len(weights), len(policy_vars)
        assert num_given == num_expected, (num_given, num_expected)
        for variable, new_value in zip(policy_vars, weights):
            variable.assign(new_value)

    @override(Policy)
    def get_exploration_state(self):
        """Returns the exploration object's state, converted to numpy.

        NOTE(review): unlike other methods of this class, this does not
        guard against `self.exploration` being falsy -- confirm callers.
        """
        return convert_to_numpy(self.exploration.get_state())

    @override(Policy)
    def is_recurrent(self):
        """Returns whether this policy currently counts as recurrent.

        The flag is set in `__init__` (non-empty initial model state) and
        re-computed in `compute_actions_from_input_dict` from the state
        inputs found in the input dict.
        """
        return self._is_recurrent

    @override(Policy)
    def num_state_tensors(self):
        """Returns the number of state tensors in the model's initial state.

        Uses `self._state_inputs`, which is captured once in `__init__`.
        """
        return len(self._state_inputs)

    @override(Policy)
    def get_initial_state(self):
        """Returns the model's initial state, or [] if no model exists yet."""
        if not hasattr(self, "model"):
            return []
        return self.model.get_initial_state()

    @override(Policy)
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def get_state(self) -> PolicyState:
        """Returns the state of this policy.

        Extends the parent class' state by the numpy-converted global
        timestep, the optimizer variables (if any) and the exploration
        state (if an exploration object exists).

        Returns:
            The policy state.
        """
        # Legacy Policy state (w/o keras model and w/o PolicySpec).
        policy_state = super().get_state()
        policy_state["global_timestep"] = policy_state["global_timestep"].numpy()

        # In the new Learner API stack, the optimizers live in the learner.
        optimizer = self._optimizer
        if optimizer and len(optimizer.variables()) > 0:
            policy_state["_optimizer_variables"] = optimizer.variables()
        else:
            policy_state["_optimizer_variables"] = []

        # Add exploration state.
        # This is not compatible with RLModules, which have a method
        # `forward_exploration` to specify custom exploration behavior.
        exploration = self.exploration
        if exploration:
            policy_state["_exploration_state"] = exploration.get_state()

        return policy_state

    @override(Policy)
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def set_state(self, state: PolicyState) -> None:
        """Restores this policy from the given state.

        Restores, in this order: optimizer variables, exploration state,
        global timestep, and then everything handled by the parent class.

        Args:
            state: The policy state. Must contain "global_timestep"; the
                keys "_optimizer_variables" and "_exploration_state" are
                optional.
        """
        # Set optimizer vars.
        optimizer_vars = state.get("_optimizer_variables", None)
        # The optimizer may be None (no optimizers) or not created yet, so do
        # not dereference it blindly (`get_state` guards this as well).
        optimizer = getattr(self, "_optimizer", None)
        if optimizer_vars and optimizer is not None and optimizer.variables():
            if not type(self).__name__.endswith("_traced") and log_once(
                "set_state_optimizer_vars_tf_eager_policy_v2"
            ):
                logger.warning(
                    "Cannot restore an optimizer's state for tf eager! Keras "
                    "is not able to save the v1.x optimizers (from "
                    "tf.compat.v1.train) since they aren't compatible with "
                    "checkpoints."
                )
            for opt_var, value in zip(optimizer.variables(), optimizer_vars):
                opt_var.assign(value)

        # Set exploration's state.
        # Policies with RLModules don't have an exploration object.
        exploration = getattr(self, "exploration", None)
        if exploration and "_exploration_state" in state:
            exploration.set_state(state=state["_exploration_state"])

        # Restore global timestep (tf vars).
        self.global_timestep.assign(state["global_timestep"])

        # Then the Policy's (NN) weights and connectors.
        super().set_state(state)

    @override(Policy)
    def export_model(self, export_dir, onnx: Optional[int] = None) -> None:
        """Exports the Policy's model to the given directory.

        Args:
            export_dir: The directory to write the exported model to.
            onnx: If truthy, export in ONNX format (as "model.onnx" inside
                `export_dir`). Otherwise, save the underlying keras model.

        Raises:
            RuntimeError: If ONNX export is requested, but `tf2onnx` is not
                installed.
        """
        if onnx:
            try:
                import tf2onnx
            except ImportError as e:
                raise RuntimeError(
                    "Converting a TensorFlow model to ONNX requires "
                    "`tf2onnx` to be installed. Install with "
                    "`pip install tf2onnx`."
                ) from e

            keras_model = self.model.base_model
            onnx_file = os.path.join(export_dir, "model.onnx")
            _model_proto, _external_tensor_storage = tf2onnx.convert.from_keras(
                keras_model,
                output_path=onnx_file,
            )
            return

        # Save the tf.keras.Model (architecture and weights, so it can be retrieved
        # w/o access to the original (custom) Model or Policy code).
        base_model = getattr(getattr(self, "model", None), "base_model", None)
        if not isinstance(base_model, tf.keras.Model):
            logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
            return

        # Best effort: A failing save only results in a warning.
        try:
            base_model.save(export_dir, save_format="tf")
        except Exception:
            logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)

    def variables(self):
        """Return the list of all savable variables for this policy."""
        model = self.model
        # keras Models expose `variables` as a property, other models as a
        # method.
        return (
            model.variables
            if isinstance(model, tf.keras.Model)
            else model.variables()
        )

    def loss_initialized(self):
        """Returns whether the loss has been initialized.

        The flag is False after `__init__` and set to True at the end of
        `maybe_initialize_optimizer_and_loss`.
        """
        return self._loss_initialized

    @with_lock
    def _compute_actions_helper(
        self,
        input_dict,
        state_batches,
        episodes,
        explore,
        timestep,
        _ray_trace_ctx=None,
    ):
        """Runs the action computation (forward pass plus sampling).

        Uses, in this order of priority: a custom `action_sampler_fn`, a
        custom `action_distribution_fn`, a keras model call, or a regular
        ModelV2 call.

        Args:
            input_dict: The inputs of the forward pass.
            state_batches: List of internal state input batches.
            episodes: Episode objects; only handed to `action_sampler_fn`.
            explore: Whether to explore.
            timestep: The current timestep.
            _ray_trace_ctx: Not used inside this method.

        Returns:
            Tuple of (actions, state outputs, extra fetches dict).
        """
        # Increase the tracing counter to make sure we don't re-trace too
        # often. If eager_tracing=True, this counter should only get
        # incremented during the @tf.function trace operations, never when
        # calling the already traced function after that.
        self._re_trace_counter += 1

        # Calculate RNN sequence lengths.
        if SampleBatch.SEQ_LENS in input_dict:
            seq_lens = input_dict[SampleBatch.SEQ_LENS]
        else:
            # No seq-lens given: all ones if there are state inputs, else None.
            batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]
            seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches else None

        # Add default and custom fetches.
        extra_fetches = {}

        # No tf variables may be created from within the forward pass.
        with tf.variable_creator_scope(_disallow_var_creation):

            if is_overridden(self.action_sampler_fn):
                actions, logp, dist_inputs, state_out = self.action_sampler_fn(
                    self.model,
                    obs_batch=input_dict[SampleBatch.OBS],
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=explore,
                    timestep=timestep,
                    episodes=episodes,
                )
            else:
                # Try `action_distribution_fn`.
                if is_overridden(self.action_distribution_fn):
                    (
                        dist_inputs,
                        self.dist_class,
                        state_out,
                    ) = self.action_distribution_fn(
                        self.model,
                        obs_batch=input_dict[SampleBatch.OBS],
                        state_batches=state_batches,
                        seq_lens=seq_lens,
                        explore=explore,
                        timestep=timestep,
                        is_training=False,
                    )
                elif isinstance(self.model, tf.keras.Model):
                    # Keras models receive their states through the input dict.
                    if state_batches and "state_in_0" not in input_dict:
                        for i, s in enumerate(state_batches):
                            input_dict[f"state_in_{i}"] = s
                    self._lazy_tensor_dict(input_dict)
                    dist_inputs, state_out, extra_fetches = self.model(input_dict)
                else:
                    dist_inputs, state_out = self.model(
                        input_dict, state_batches, seq_lens
                    )

                action_dist = self.dist_class(dist_inputs, self.model)

                # Get the exploration action from the forward results.
                # NOTE(review): `self.exploration` is used unguarded here,
                # whereas other methods check it for truthiness first.
                actions, logp = self.exploration.get_exploration_action(
                    action_distribution=action_dist,
                    timestep=timestep,
                    explore=explore,
                )

        # Action-logp and action-prob.
        if logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
            extra_fetches[SampleBatch.ACTION_LOGP] = logp
        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
        # Custom extra fetches.
        extra_fetches.update(self.extra_action_out_fn())

        return actions, state_out, extra_fetches

    # TODO: Figure out, why _ray_trace_ctx=None helps to prevent a crash in
    #  AlphaStar w/ framework=tf2; eager_tracing=True on the policy learner actors.
    #  It seems there may be a clash between the traced-by-tf function and the
    #  traced-by-ray functions (for making the policy class a ray actor).
    def _learn_on_batch_helper(self, samples, _ray_trace_ctx=None):
        """Computes the gradients for `samples` and applies them.

        Args:
            samples: The train batch to compute the gradients from.
            _ray_trace_ctx: Not used inside this method (see TODO above).

        Returns:
            The stats returned by `_compute_gradients_helper`.
        """
        # Increase the tracing counter to make sure we don't re-trace too
        # often. If eager_tracing=True, this counter should only get
        # incremented during the @tf.function trace operations, never when
        # calling the already traced function after that.
        self._re_trace_counter += 1

        # Only the gradient computation runs with variable creation disallowed;
        # applying the gradients happens outside of that scope.
        with tf.variable_creator_scope(_disallow_var_creation):
            grads_and_vars, _, stats = self._compute_gradients_helper(samples)
        self._apply_gradients_helper(grads_and_vars)
        return stats

    def _get_is_training_placeholder(self):
        """Returns the current `self._is_training` flag as a tf tensor."""
        return tf.convert_to_tensor(self._is_training)

    @with_lock
    def _compute_gradients_helper(self, samples):
        """Compute loss gradients for `samples` plus the accompanying stats.

        Args:
            samples: The train batch; passed to `self.loss()` and
                `self._stats()`.

        Returns:
            A 3-tuple `(grads_and_vars, grads, stats)`. If
            `config["_tf_policy_handles_more_than_one_loss"]` is set,
            `grads_and_vars` is a list (one entry per loss/optimizer) of lists
            of (grad, var) tuples and `grads` is a list of lists of gradients.
            Otherwise, `grads_and_vars` is a flat list of (grad, var) tuples
            and `grads` a flat list of gradients. `stats` is whatever
            `self._stats(samples, grads)` returns.
        """
        # Bump the trace counter. With eager_tracing=True, this line should
        # only execute while tf.function traces this method, never on calls
        # to the already traced function, so the counter exposes re-tracing.
        self._re_trace_counter += 1

        # Collect the variables to differentiate with respect to. Keras models
        # expose `trainable_variables` as an attribute, other models as a
        # method.
        model = self.model
        trainable = (
            model.trainable_variables
            if isinstance(model, tf.keras.Model)
            else model.trainable_variables()
        )

        # Did the user provide their own `compute_gradients_fn`?
        custom_grad_fn = is_overridden(self.compute_gradients_fn)

        # Evaluate the loss(es) under a GradientTape. The tape is persistent
        # only on the custom path (presumably because a user function may
        # query the tape more than once).
        with tf.GradientTape(persistent=custom_grad_fn) as tape:
            losses = self.loss(model, self.dist_class, samples)
        losses = force_list(losses)

        if custom_grad_fn:
            # Disguise the tape as a "classic" tf optimizer so that the same
            # custom `compute_gradients_fn` works for tf static graph and
            # tf-eager alike.
            tape_optimizer = _OptimizerWrapper(tape)
            if self.config["_tf_policy_handles_more_than_one_loss"]:
                # Several loss terms: Hand over one (shared) wrapper per loss.
                grads_and_vars = self.compute_gradients_fn(
                    [tape_optimizer] * len(losses), losses
                )
            else:
                # Exactly one loss and one optimizer.
                grads_and_vars = [
                    self.compute_gradients_fn(tape_optimizer, losses[0])
                ]
        else:
            # Default path: Differentiate each loss directly via the tape.
            grads_and_vars = []
            for loss in losses:
                loss_grads = tape.gradient(loss, trainable)
                grads_and_vars.append(list(zip(loss_grads, trainable)))

        # Log (once only) every variable that actually receives a gradient.
        if log_once("grad_vars"):
            for pairs in grads_and_vars:
                for g, v in pairs:
                    if g is None:
                        continue
                    logger.info(f"Optimizing variable {v.name}")

        if not self.config["_tf_policy_handles_more_than_one_loss"]:
            # Single loss: Strip the outer list, leaving a flat list of
            # (grad, var) tuples.
            grads_and_vars = grads_and_vars[0]
            grads = [g for g, _ in grads_and_vars]
        else:
            # Multiple losses: Keep the outer list (len=num optimizers/losses)
            # of lists of (grad, var) tuples.
            grads = [[g for g, _ in pairs] for pairs in grads_and_vars]

        return grads_and_vars, grads, self._stats(samples, grads)

    def _apply_gradients_helper(self, grads_and_vars):
        """Apply gradients via a custom `apply_gradients_fn` or the optimizer(s).

        Args:
            grads_and_vars: If
                `config["_tf_policy_handles_more_than_one_loss"]` is set, a
                sequence indexed per optimizer whose items are iterables of
                (grad, var) tuples. Otherwise, a single iterable of
                (grad, var) tuples.
        """
        # Bump the trace counter. With eager_tracing=True, this line should
        # only execute while tf.function traces this method, never on calls
        # to the already traced function, so the counter exposes re-tracing.
        self._re_trace_counter += 1

        custom_apply = is_overridden(self.apply_gradients_fn)
        multi_loss = self.config["_tf_policy_handles_more_than_one_loss"]

        # Custom path: Pass everything through to the user's function, which
        # receives all optimizers (multi-loss) or the single optimizer. Note
        # that None-gradients are NOT filtered out here.
        if custom_apply:
            if multi_loss:
                self.apply_gradients_fn(self._optimizers, grads_and_vars)
            else:
                self.apply_gradients_fn(self._optimizer, grads_and_vars)
            return

        # Default path: Drop all pairs w/o a gradient, then let the
        # optimizer(s) apply the rest.
        if multi_loss:
            for idx, optim in enumerate(self._optimizers):
                usable = [(g, v) for g, v in grads_and_vars[idx] if g is not None]
                optim.apply_gradients(usable)
        else:
            usable = [(g, v) for g, v in grads_and_vars if g is not None]
            self._optimizer.apply_gradients(usable)

    def _stats(self, samples, grads):
        """Assemble the fetches dict returned alongside the gradients.

        Args:
            samples: The train batch; passed to `stats_fn` (if overridden)
                and to `grad_stats_fn`.
            grads: The gradients; passed to `grad_stats_fn`.

        Returns:
            A dict holding the learner stats under `LEARNER_STATS_KEY` (empty
            if `stats_fn` is not overridden), updated first with the results
            of `extra_learn_fetches_fn()`, then with those of
            `grad_stats_fn(samples, grads)`.
        """
        learner_stats = (
            dict(self.stats_fn(samples)) if is_overridden(self.stats_fn) else {}
        )
        fetches = {LEARNER_STATS_KEY: learner_stats}

        # On key clashes, later updates win.
        extra_fetches = dict(self.extra_learn_fetches_fn())
        fetches.update(extra_fetches)
        grad_stats = dict(self.grad_stats_fn(samples, grads))
        fetches.update(grad_stats)

        return fetches

    def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch):
        """Return `postprocessed_batch` with `_convert_to_tf` as get-interceptor.

        Args:
            postprocessed_batch: The batch to prepare. Anything that is not
                already a SampleBatch gets wrapped into a new one.

        Returns:
            The SampleBatch (the very same object that was passed in, if it
            already was a SampleBatch) after `set_get_interceptor` has been
            called on it.
        """
        # TODO: (sven): Keep for a while to ensure backward compatibility.
        batch = (
            postprocessed_batch
            if isinstance(postprocessed_batch, SampleBatch)
            else SampleBatch(postprocessed_batch)
        )
        # Presumably makes values get converted to tf tensors on access;
        # see `_convert_to_tf` to confirm.
        batch.set_get_interceptor(_convert_to_tf)
        return batch

    @classmethod
    def with_tracing(cls):
        """Return the result of passing this class to `_traced_eager_policy`.

        Presumably, this is the traced-execution variant of the policy class;
        see `_traced_eager_policy` for what exactly gets wrapped.
        """
        traced = _traced_eager_policy(cls)
        return traced
