import functools
import logging
import os
import platform
import queue
import sys
import threading
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type

import ray
from ray.air._internal.util import RunnerThread, StartTraceback
from ray.air.constants import (
    _ERROR_FETCH_TIMEOUT,
    _RESULT_FETCH_TIMEOUT,
    SESSION_MISUSE_LOG_ONCE_KEY,
    TIME_THIS_ITER_S,
    TIMESTAMP,
)
from ray.data import Dataset
from ray.train import Checkpoint
from ray.train._internal.accelerator import Accelerator
from ray.train._internal.storage import StorageContext
from ray.train.constants import (
    CHECKPOINT_DIR_NAME,
    DETAILED_AUTOFILLED_KEYS,
    RAY_CHDIR_TO_TRIAL_DIR,
    TIME_TOTAL_S,
    WORKER_HOSTNAME,
    WORKER_NODE_IP,
    WORKER_PID,
    _v2_migration_warnings_enabled,
)
from ray.train.error import SessionMisuseError
from ray.train.utils import _log_deprecation_warning
from ray.util import queue as ray_queue
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.util.debug import log_once
from ray.util.placement_group import _valid_resource_shape
from ray.util.scheduling_strategies import (
    PlacementGroupSchedulingStrategy,
    SchedulingStrategyT,
)

if TYPE_CHECKING:
    from ray.data import DataIterator
    from ray.tune.execution.placement_groups import PlacementGroupFactory


# Module-level logger used by the session and its helpers below.
logger = logging.getLogger(__name__)


@dataclass
class TrialInfo:
    """The trial information to propagate to TrainSession.

    An instance is stored as ``_TrainSession.trial_info`` and read back through
    the session's ``experiment_name``, ``trial_name``, ``trial_id``, ``run_id``,
    ``trial_resources`` and ``trial_dir`` properties.
    """

    # Trial name; exposed as `_TrainSession.trial_name`.
    name: str
    # Trial id; exposed as `_TrainSession.trial_id`.
    id: str
    # Resources of the trial; exposed as `_TrainSession.trial_resources`.
    # NOTE(review): annotated as a dict here, but `trial_resources` is annotated
    # as `PlacementGroupFactory` and the launch hook reads
    # `.head_bundle_is_empty` from it -- confirm the actual runtime type.
    resources: Dict[str, float]
    # Trial directory; exposed as `_TrainSession.trial_dir`.
    logdir: str
    # Presumably the IP address / node id of the driver process -- not read
    # anywhere in this module; verify against callers.
    driver_ip: str
    driver_node_id: str
    # Optional experiment name; exposed as `_TrainSession.experiment_name`.
    experiment_name: Optional[str] = None
    # Optional run id; exposed as `_TrainSession.run_id`.
    run_id: Optional[str] = None


class _FutureTrainingResult:
    """Wrap an ``ObjectRef`` that will eventually resolve to a ``_TrainingResult``.

    Schedulers that schedule saves (e.g. PBT) need this indirection. The wrapper
    should go away once PBT no longer schedules saves.
    """

    def __init__(self, future: ray.ObjectRef):
        self.future = future

    def resolve(self, block: bool = True) -> Optional["_TrainingResult"]:
        """Fetch the wrapped ``_TrainingResult``.

        Args:
            block: If True, wait for the result indefinitely. If False, only
                poll (a near-zero timeout is used).

        Returns:
            The resolved value. Returns None if it is not ready yet (non-blocking
            mode) or if resolving it raised (the error is logged). For function
            trainables the resolved value itself is None when no checkpoint has
            been saved before.
        """
        # A tiny positive timeout turns `ray.get` into a non-blocking poll.
        wait_s = None if block else 1e-9
        try:
            return ray.get(self.future, timeout=wait_s)
        except TimeoutError:
            # The result is not available yet.
            return None
        except Exception as exc:
            logger.error(f"Error resolving result: {exc}")
            return None


class _TrainingResult:
    """A ``(checkpoint, metrics)`` pair reported from the training function."""

    def __init__(self, checkpoint: Optional[Checkpoint], metrics: Dict[str, Any]):
        # Plain attributes; both are read directly by consumers of the result.
        self.metrics = metrics
        self.checkpoint = checkpoint

    def __repr__(self) -> str:
        ckpt, metric_dict = self.checkpoint, self.metrics
        return f"TrainingResult(checkpoint={ckpt}, metrics={metric_dict})"


# TODO(xwjiang): This needs a better name.
@DeveloperAPI
class _TrainSession:
    """Holds information for training on each worker.

    The user's ``training_func`` runs in a daemon ``RunnerThread`` that is
    started by :meth:`start`. Each :meth:`report` from that thread does two
    things. First, it puts a ``_TrainingResult`` on ``result_queue``, which
    holds at most one item. Second, it blocks on ``continue_lock`` until the
    consumer of :meth:`get_next` releases it. Exceptions from the training
    thread are surfaced through ``error_queue``.
    """

    def __init__(
        self,
        training_func: Callable,
        world_rank: Optional[int],
        local_rank: Optional[int],
        node_rank: Optional[int],
        local_world_size: Optional[int],
        world_size: Optional[int],
        trial_info: Optional[TrialInfo] = None,
        dataset_shard: Optional[Dict[str, Dataset]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        checkpoint: Optional[Checkpoint] = None,
        detailed_autofilled_metrics: bool = False,
        storage: Optional[StorageContext] = None,
        synchronous_result_reporting: bool = False,
    ):
        """Create the session and prepare (but do not start) the training thread.

        Args:
            training_func: The function to run in the training thread.
            world_rank, local_rank, node_rank, local_world_size, world_size:
                Ray Train worker topology. None for Tune function Trainables.
            trial_info: Trial information exposed through the trial properties.
            dataset_shard: Dataset shard(s) returned by `get_dataset_shard`.
            metadata: Trainer-level metadata merged into reported checkpoints.
            checkpoint: Checkpoint to resume from (`loaded_checkpoint`).
            detailed_autofilled_metrics: Keep the keys in
                `DETAILED_AUTOFILLED_KEYS` when autofilling metrics.
            storage: Storage context. Required despite the None default.
            synchronous_result_reporting: See the comment below.
        """
        # `synchronous_result_reporting` controls when the training function is
        # unblocked after it reports a result (see `get_next`).
        # Ex 1: For 2 Ray Train workers with synchronous_result_reporting=False,
        # each worker is released as soon as the main thread fetches its result.
        # The worker that produces a result first will immediately continue
        # onto the next iteration.
        # Ex 2: For a Tune function Trainable with
        # `synchronous_result_reporting=True`, training will only continue with
        # an explicit (subsequent) call to `session.get_next`.
        # Synchronous reporting in example 2 is needed for Tune schedulers to
        # be able to stop the execution of the training function at will,
        # for advanced pausing schedulers (PBT, BOHB) and actor reuse.
        self.synchronous_result_reporting = synchronous_result_reporting

        # Ray Train worker properties
        # Note: These are set to None for Tune function Trainables.
        self.dataset_shard = dataset_shard
        self.metadata = metadata

        self.world_rank = world_rank
        self.local_rank = local_rank
        self.node_rank = node_rank
        self.local_world_size = local_world_size
        self.world_size = world_size

        assert storage
        logger.debug(f"StorageContext on SESSION (rank={world_rank}):\n{storage}")

        # NOTE: `reset` will initialize many properties needed to start running the
        # training_func as a thread (including the user `_state` dict).
        self.reset(
            training_func=training_func,
            trial_info=trial_info,
            storage=storage,
            loaded_checkpoint=checkpoint,
        )

        # Autofilled metrics attributes.
        self.detailed_autofilled_metrics = detailed_autofilled_metrics
        self.last_report_time = time.time()
        self.iteration = 0
        self.time_total = 0.0
        self.local_ip = self.get_current_ip()

        self.accelerator = None

    def get_state(self, key: str) -> Any:
        """Return the value stored under ``key``, or None if unset."""
        return self._state.get(key)

    def set_state(self, key: str, value: Any):
        """Store ``value`` under ``key``. The state is cleared by :meth:`reset`."""
        self._state[key] = value

    def get_current_ip(self):
        """Look up this node's IP address, cache it on ``local_ip`` and return it."""
        self.local_ip = ray.util.get_node_ip_address()
        return self.local_ip

    def start(self):
        """Starts the training thread."""
        self.training_started = True
        self.training_thread.start()

    def reset(
        self,
        training_func: Callable,
        trial_info: TrialInfo,
        storage: StorageContext,
        loaded_checkpoint=None,
    ):
        """(Re)initialize the per-run state for running ``training_func``.

        The following are recreated: the lock, the stop event, the queues and
        the training thread. Trial info, storage and the loaded checkpoint are
        replaced, and the user state dict is cleared. The autofilled-metric
        counters (``iteration``, ``time_total``) are not touched.

        Side effect: creates ``storage.trial_working_directory``. Unless
        ``RAY_CHDIR_TO_TRIAL_DIR`` is set to 0, it also changes the process
        working directory to it.
        """
        # This lock is used to control the execution of the training thread.
        self.continue_lock = threading.Semaphore(0)

        # This event is used to signal the training thread to stop.
        self.stop_event = threading.Event()

        # Queue for sending results across threads.
        self.result_queue = queue.Queue(1)

        # Queue for sending results from training actor to main thread.
        self._inter_actor_queue: Optional[ray_queue.Queue[Dict]] = None

        # Queue for raising exceptions from runner thread to main thread.
        # The error queue has a max size of one to prevent stacking error and force
        # error reporting to block until finished.
        self.error_queue = queue.Queue(1)

        # The Thread object that is running the training function.
        self.training_thread = RunnerThread(
            target=training_func, daemon=True, error_queue=self.error_queue
        )

        # Possibly override with new state
        self.trial_info = trial_info
        self.storage = storage
        self.loaded_checkpoint = loaded_checkpoint

        # Reset state
        self._state = {}
        self.ignore_report = False
        self.training_started = False
        self._first_report = True

        # Change the working directory to a special trial folder.
        # This is to ensure that all Ray Train workers have a common working directory.
        os.makedirs(storage.trial_working_directory, exist_ok=True)
        if bool(int(os.environ.get(RAY_CHDIR_TO_TRIAL_DIR, "1"))):
            logger.debug(
                f"Changing the working directory to: {storage.trial_working_directory}"
            )
            os.chdir(storage.trial_working_directory)

    def pause_reporting(self):
        """Ignore all future ``session.report()`` calls."""
        self.ignore_report = True

    def finish(self, timeout: Optional[float] = None) -> Optional[Any]:
        """Stop the training thread, flush artifacts and join the thread.

        Args:
            timeout: Passed to ``training_thread.join``.

        Returns:
            The value returned by ``training_thread.join`` (the training
            function's output), or None if training was never started.

        Raises:
            Any exception raised during training (re-raised by ``join``).
        """
        # Set the stop event for the training thread to gracefully exit.
        self.stop_event.set()

        # Release the lock so that training thread can process this event.
        self.continue_lock.release()

        # Force a final (blocking) sync of artifacts in the trial path to storage.
        self.storage.persist_artifacts(force=True)

        # Wait for training to finish.
        # This will raise any errors that occur during training, including SystemError
        # This returns the result of the training function.
        output = None
        if self.training_started:
            output = self.training_thread.join(timeout=timeout)

        return output

    def get_next(self) -> Optional[_TrainingResult]:
        """Return the next ``_TrainingResult`` reported by the training thread.

        Blocks (polling with a timeout) while the training thread is alive and
        has not reported yet. Returns None once the thread has exited and no
        result is left in the queues.

        The training thread is released at one of two points. With
        ``synchronous_result_reporting=True``, it is released at the start of
        every call except the first, so it continues past its previous report.
        Otherwise it is released right after a result is fetched.

        Raises:
            RuntimeError: If called before :meth:`start`.
            StartTraceback: Chained from the training thread's exception, if
                the thread failed and no result remains to be returned.
        """
        if not self.training_started:
            raise RuntimeError("Please call start before calling get_next.")

        if self.synchronous_result_reporting:
            # There's no need to release the lock on the first report
            # since `start` already started the training thread.
            if not self._first_report:
                # Release the lock to trigger training to continue,
                # until the next call to report.
                self.continue_lock.release()
            self._first_report = False

        result = None
        # While training is still ongoing, attempt to get the result.
        while result is None and self.training_thread.is_alive():
            result = self._get_result_from_queues(block=True)

        # If no result was found, then the runner must no longer be alive.
        if result is None:
            # Try one last time to fetch results in case results were
            # reported in between the time of the last check and the
            # termination of the thread runner.
            result = self._get_result_from_queues(block=False)

        # check if error occurred inside the thread runner.
        if result is None:
            # only raise an error from the runner if all results are consumed
            self._report_thread_runner_error(block=True)
        else:
            if not self.error_queue.empty():
                logger.debug(
                    (
                        "Runner error waiting to be raised in main thread. "
                        "Logging all available results first."
                    )
                )

        if not self.synchronous_result_reporting:
            # At this point, the training thread has reached
            # the `train.report` and is blocked there.
            # If performing asynchronous result reporting,
            # release the lock to allow each worker to keep training
            # immediately after the coordinator fetches their result.
            self.continue_lock.release()

        # Return None if there are no more results to fetch.
        return result

    def _get_or_create_inter_actor_queue(self):
        """Return the inter-actor queue, creating it (size 1, 0 CPUs) on first use."""
        if self._inter_actor_queue is None:
            self._inter_actor_queue = ray_queue.Queue(1, actor_options={"num_cpus": 0})
        return self._inter_actor_queue

    def _get_result_from_queues(self, block: bool) -> Optional[_TrainingResult]:
        """Fetch one result from ``result_queue``, or None if none is available.

        If the inter-actor queue exists and holds metrics, those metrics are
        first re-reported through :meth:`report` from the calling thread. The
        re-reported result is placed on ``result_queue``. Each queue read uses
        ``_RESULT_FETCH_TIMEOUT`` when ``block`` is True.
        """
        result = None
        if self._inter_actor_queue is not None:
            try:
                inter_actor_item = self._inter_actor_queue.get(
                    block=block, timeout=_RESULT_FETCH_TIMEOUT
                )
                if inter_actor_item:
                    # Must release continue_lock to allow report to work.
                    self.continue_lock.release()
                    self.report(inter_actor_item)
            except ray_queue.Empty:
                pass
        try:
            result = self.result_queue.get(block=block, timeout=_RESULT_FETCH_TIMEOUT)
        except queue.Empty:
            pass
        return result

    def _auto_fill_metrics(self, result: dict) -> dict:
        """Return a copy of ``result`` with autofilled metrics added.

        Also advances ``iteration``, ``time_total`` and ``last_report_time``. A
        user-supplied ``TIME_THIS_ITER_S`` takes precedence over the measured
        time since the last report. Detailed keys are dropped unless
        ``detailed_autofilled_metrics`` is set.
        """
        current_time = time.time()
        current_datetime = datetime.now()
        if TIME_THIS_ITER_S in result:
            time_this_iter = result[TIME_THIS_ITER_S]
        else:
            time_this_iter = current_time - self.last_report_time
        self.iteration += 1
        self.time_total += time_this_iter
        self.last_report_time = current_time

        auto_filled_metrics = {
            TIMESTAMP: int(time.mktime(current_datetime.timetuple())),
            TIME_TOTAL_S: self.time_total,
            WORKER_PID: os.getpid(),
            WORKER_HOSTNAME: platform.node(),
            WORKER_NODE_IP: self.local_ip,
        }

        if not self.detailed_autofilled_metrics:
            auto_filled_metrics = {
                k: v
                for k, v in auto_filled_metrics.items()
                if k not in DETAILED_AUTOFILLED_KEYS
            }

        result = result.copy()
        result.update(auto_filled_metrics)
        return result

    def _auto_fill_checkpoint_metrics(self, result: dict) -> dict:
        """Return a copy of ``result`` with only ``TIMESTAMP`` added.

        Unlike :meth:`_auto_fill_metrics`, no session attributes are updated.
        """
        current_datetime = datetime.now()

        auto_filled_metrics = {
            TIMESTAMP: int(time.mktime(current_datetime.timetuple()))
        }
        result = result.copy()
        result.update(auto_filled_metrics)
        return result

    def _report_thread_runner_error(self, block=False):
        """Raise ``StartTraceback`` chained from a queued training-thread error.

        Does nothing if no error arrives within ``_ERROR_FETCH_TIMEOUT`` (when
        ``block``) or if the error queue is empty (when not ``block``).
        """
        try:
            e = self.error_queue.get(block=block, timeout=_ERROR_FETCH_TIMEOUT)
            raise StartTraceback from e
        except queue.Empty:
            pass

    def _report_training_result(self, training_result: _TrainingResult) -> None:
        """Place a training result on the result queue for the main thread to process,
        then block until the main thread signals that training should continue.

        NOTE: This is used internally to report results from Train to Tune
        without persisting checkpoints to storage 2 times.
        `report` is the public API that directly persists to storage, which
        should only be called by user code.
        """
        if training_result.checkpoint:
            # NOTE: This populates `train.get_checkpoint`
            self.loaded_checkpoint = training_result.checkpoint

        # Add result to a thread-safe queue.
        self.result_queue.put(training_result, block=True)

        # Acquire lock to stop the training thread until main thread
        # triggers resume.
        self.continue_lock.acquire()

        # If the trial should be terminated, exit gracefully.
        # NOTE: This is only really useful if `synchronous_result_reporting=True`.
        # Otherwise, the lock is immediately released on reporting, and this
        # check is skipped before the main thread decides to set the stop event.
        if self.stop_event.is_set():
            self.stop_event.clear()
            # Raises SystemExit in the calling (training) thread.
            sys.exit(0)

    def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
        """Autofill ``metrics``, persist ``checkpoint`` and report the result.

        Args:
            metrics: Metrics to report. These must not contain Torch tensors.
            checkpoint: Optional checkpoint. It is persisted to storage and
                annotated with the Trainer-level ``metadata``.

        Raises:
            ValueError: If ``metrics`` contains Torch tensors.
        """
        # Special case: early fail for Torch tensors
        if "torch" in sys.modules:
            from ray.air._internal.torch_utils import contains_tensor

            if contains_tensor(metrics):
                raise ValueError(
                    "Passing objects containg Torch tensors as metrics "
                    "is not supported as it will throw an exception on "
                    "deserialization. You can either convert the tensors "
                    "to Python objects or report a `train.Checkpoint` "
                    "with `ray.train.report` to store your Torch objects."
                )

        if self.ignore_report:
            return

        metrics = self._auto_fill_metrics(metrics)

        persisted_checkpoint = None
        if checkpoint:
            self.storage._update_checkpoint_index(metrics)

            # Persist the reported checkpoint files to storage.
            persisted_checkpoint = self.storage.persist_current_checkpoint(checkpoint)

            metrics[CHECKPOINT_DIR_NAME] = self.storage.checkpoint_dir_name
        else:
            metrics[CHECKPOINT_DIR_NAME] = None

        # Persist trial artifacts to storage. Coerce to bool: without a
        # persisted checkpoint the `and` expression would evaluate to None.
        force_artifact_sync = bool(
            persisted_checkpoint
            and self.storage.sync_config.sync_artifacts_on_checkpoint
        )
        self.storage.persist_artifacts(force=force_artifact_sync)

        # Set additional user metadata from the Trainer.
        if persisted_checkpoint and self.metadata:
            user_metadata = persisted_checkpoint.get_metadata()
            for k, v in self.metadata.items():
                # Update keys not already set by the user. This gives user-set keys
                # precedence over keys set at the Trainer level.
                if k not in user_metadata:
                    user_metadata[k] = v
            persisted_checkpoint.set_metadata(user_metadata)

        result = _TrainingResult(checkpoint=persisted_checkpoint, metrics=metrics)

        self._report_training_result(result)

    @property
    def experiment_name(self) -> Optional[str]:
        return self.trial_info.experiment_name

    @property
    def trial_name(self) -> str:
        return self.trial_info.name

    @property
    def trial_id(self) -> str:
        return self.trial_info.id

    @property
    def run_id(self) -> Optional[str]:
        return self.trial_info.run_id

    @property
    def trial_resources(self) -> "PlacementGroupFactory":
        return self.trial_info.resources

    @property
    def trial_dir(self) -> str:
        return self.trial_info.logdir

    def get_dataset_shard(
        self,
        dataset_name: Optional[str] = None,
    ) -> Optional["DataIterator"]:
        """Return the dataset shard for this worker.

        Args:
            dataset_name: Required when the shards are stored as a dict.

        Returns:
            The shard for ``dataset_name``, or None if that name is unknown. When
            the shard is not a dict, it is returned as is. When no dataset was
            passed, a warning is emitted and None is returned.

        Raises:
            RuntimeError: If the shards are a dict and no ``dataset_name`` is given.
        """
        shard = self.dataset_shard
        if shard is None:
            warnings.warn(
                "No dataset passed in. Returning None. Make sure to "
                "pass in a Dataset to Trainer.run to use this "
                "function."
            )
        elif isinstance(shard, dict):
            if not dataset_name:
                raise RuntimeError(
                    "Multiple datasets were passed into ``Trainer``, "
                    "but no ``dataset_name`` is passed into "
                    "``get_dataset_shard``. Please specify which "
                    "dataset shard to retrieve."
                )
            return shard.get(dataset_name)
        return shard


# Cache of resource dicts that have been checked by the launch hook already.
# Entries are frozensets of the positive (resource_name, amount) pairs, so each
# distinct resource shape is validated at most once while this module is loaded.
_checked_resources: Set[frozenset] = set()

# Global _TrainSession object initialized by Ray Tune function trainables
# and Ray Train V1 workers. Set by `init_session`, cleared by `shutdown_session`.
_session: Optional[_TrainSession] = None


def _tune_task_and_actor_launch_hook(
    fn, resources: Dict[str, float], strategy: Optional[SchedulingStrategyT]
):
    """Launch hook to catch nested tasks that can't fit in the placement group.

    This gives users a nice warning in case they launch a nested task in a Tune trial
    without reserving resources in the trial placement group to fit it.

    ``init_session`` installs this as both the actor and the task launch hook.

    Args:
        fn: Descriptor of the function being launched. Its ``module_name``,
            ``class_name`` and ``function_name`` are used to name it in the error.
        resources: The requested resources by name. Zero amounts are ignored.
        strategy: The launch's scheduling strategy. Only a
            ``PlacementGroupSchedulingStrategy`` that targets the current
            placement group is checked.

    Raises:
        RuntimeError: If the request cannot be satisfied by the bundles of the
            current placement group that are available to nested work.
    """

    # Already checked, skip for performance reasons.
    key = frozenset({(k, v) for k, v in resources.items() if v > 0})
    if not key or key in _checked_resources:
        return

    # No need to check if placement group is None.
    if (
        not isinstance(strategy, PlacementGroupSchedulingStrategy)
        or strategy.placement_group is None
    ):
        return

    # Check if the resource request is targeting the current placement group.
    cur_pg = ray.util.get_current_placement_group()
    if not cur_pg or strategy.placement_group.id != cur_pg.id:
        return

    # `init_session` installs this hook, but `shutdown_session` never removes it.
    # The hook can therefore fire while no session is active, and then there are
    # no trial resources to compare against. In that case, skip the check without
    # caching the key, so the key is still checked once a session exists.
    pgf = get_trial_resources() if get_session() is not None else None
    if pgf is None:
        return

    _checked_resources.add(key)

    # Check if the request can be fulfilled by the current placement group.
    # When the head bundle is empty it is usable by nested work too.
    if pgf.head_bundle_is_empty:
        available_bundles = cur_pg.bundle_specs[0:]
    else:
        available_bundles = cur_pg.bundle_specs[1:]

    if _valid_resource_shape(resources, available_bundles):
        return

    if fn.class_name:
        submitted = "actor"
        name = fn.module_name + "." + fn.class_name + "." + fn.function_name
    else:
        submitted = "task"
        name = fn.module_name + "." + fn.function_name

    # Normalize the resource spec so it looks the same as the placement group bundle.
    main_resources = cur_pg.bundle_specs[0]
    resources = {k: float(v) for k, v in resources.items() if v > 0}

    raise RuntimeError(
        f"No trial resources are available for launching the {submitted} `{name}`. "
        "To resolve this, specify the Tune option:\n\n"
        ">  resources_per_trial=tune.PlacementGroupFactory(\n"
        f">    [{main_resources}] + [{resources}] * N\n"
        ">  )\n\n"
        f"Where `N` is the number of slots to reserve for trial {submitted}s. "
        "If you are using a Ray training library, there might be a utility function "
        "to set this automatically for you. For more information, refer to "
        "https://docs.ray.io/en/latest/tune/tutorials/tune-resources.html"
    )


def init_session(*args, **kwargs) -> None:
    """Create the process-wide ``_TrainSession`` from ``*args``/``**kwargs``.

    Unless ``TUNE_DISABLE_RESOURCE_CHECKS`` is set in the environment, this also
    installs the placement-group resource launch hook for actors and tasks.

    Raises:
        ValueError: If a session already exists.
    """
    global _session
    if _session:
        raise ValueError(
            "A Train session is already in use. Do not call "
            "`init_session()` manually."
        )

    # Setup hooks for generating placement group resource deadlock warnings.
    from ray import actor, remote_function

    checks_disabled = "TUNE_DISABLE_RESOURCE_CHECKS" in os.environ
    if not checks_disabled:
        hook = _tune_task_and_actor_launch_hook
        actor._actor_launch_hook = hook
        remote_function._task_launch_hook = hook

    _session = _TrainSession(*args, **kwargs)


def get_session() -> Optional[_TrainSession]:
    """Return the active process-wide session, or ``None`` when there is none."""
    current = _session
    return current


def shutdown_session():
    """Forget the process-wide session so a new one can be initialized.

    Only the session reference is cleared. Launch hooks installed by
    ``init_session`` stay in place.
    """
    global _session
    _session = None


def _raise_accelerator_session_misuse():
    """Raise ``SessionMisuseError``: an accelerator helper ran outside a session."""
    message = (
        "prepare/accelerate utility functions should be called inside a training "
        "function executed by `Trainer.run`"
    )
    raise SessionMisuseError(message)


def get_accelerator(default_accelerator_cls: Type[Accelerator]) -> Accelerator:
    """Return the session's accelerator, creating a default one if none is set.

    Args:
        default_accelerator_cls: Class instantiated (with no arguments) when the
            session has no accelerator yet. The new instance is stored on the
            session.

    Raises:
        SessionMisuseError: if the session is uninitialized.
    """
    session = get_session()
    if session is None:
        _raise_accelerator_session_misuse()

    current = session.accelerator
    if current is None:
        current = default_accelerator_cls()
        session.accelerator = current
    return current


def set_accelerator(accelerator: Accelerator) -> None:
    """Install ``accelerator`` on the current session (only allowed once).

    Args:
        accelerator: The accelerator to use for training.

    Raises:
        SessionMisuseError: if the session is uninitialized.
        RuntimeError: if the accelerator has already been set.
    """
    session = get_session()
    if session is None:
        _raise_accelerator_session_misuse()
    elif session.accelerator is not None:
        raise RuntimeError("Cannot change accelerator once set.")
    session.accelerator = accelerator


def _warn_session_misuse(default_value: Any = None):
    """Warns if fn is being used outside of session and returns ``default_value``."""

    def inner(fn: Callable):
        fn_name = fn.__name__

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            session = get_session()
            if not session:
                if log_once(f"{SESSION_MISUSE_LOG_ONCE_KEY}-{fn_name}"):
                    warnings.warn(
                        f"`{fn_name}` is meant to only be "
                        "called inside a function that is executed by a Tuner"
                        f" or Trainer. Returning `{default_value}`."
                    )
                return default_value
            return fn(*args, **kwargs)

        return wrapper

    return inner


@PublicAPI(stability="stable")
@_warn_session_misuse()
def report(
    metrics: Dict,
    *,
    checkpoint: Optional[Checkpoint] = None,
    checkpoint_dir_name: Optional[str] = None,
) -> None:
    """Report metrics and optionally save a checkpoint.

    If a checkpoint is provided, it will be
    :ref:`persisted to storage <persistent-storage-guide>`.

    If this is called in multiple distributed training workers:

    - Only the metrics reported by the rank 0 worker will be tracked by Ray Train.
      See :ref:`the metrics logging guide <train-monitoring-and-logging>`.
    - A checkpoint will be registered as long as one or more workers reports
      a checkpoint that is not None.
      See the :ref:`checkpointing guide <train-dl-saving-checkpoints>`.
    - Checkpoints from multiple workers will be merged into one directory
      in persistent storage.
      See :ref:`the distributed checkpointing guide <train-distributed-checkpointing>`.

    .. note::

        Each invocation of this method will automatically increment the underlying
        ``training_iteration`` number. The physical meaning of this "iteration" is
        defined by user depending on how often they call ``report``.
        It does not necessarily map to one epoch.

    .. warning::

        All workers must call `ray.train.report` the same number of times
        so that Ray Train can properly synchronize the training state across
        workers. Otherwise, your training will hang.

    .. warning::

        This method does NOT act as a barrier for distributed training workers.
        Workers will upload their checkpoint, then continue training immediately.
        If you need to synchronize workers, you can use a framework-native barrier
        such as `torch.distributed.barrier()`.

    Example:

        .. testcode::

            import tempfile

            import ray
            from ray import train
            from ray.train import Checkpoint
            from ray.train.torch import TorchTrainer


            def train_func(config):
                start_epoch = 0
                checkpoint = train.get_checkpoint()
                if checkpoint:
                    with checkpoint.as_directory() as checkpoint_dir:
                        # Load back training state
                        ...

                for epoch in range(start_epoch, config.get("num_epochs", 10)):
                    # Do training...

                    metrics = {"loss": ...}

                    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                        # Save the checkpoint...
                        # torch.save(...)

                        checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

                        # Example: Only the rank 0 worker uploads the checkpoint.
                        if ray.train.get_context().get_world_rank() == 0:
                            train.report(metrics, checkpoint=checkpoint)
                        else:
                            train.report(metrics, checkpoint=None)

            trainer = TorchTrainer(
                train_func, scaling_config=train.ScalingConfig(num_workers=2)
            )

    Args:
        metrics: The metrics you want to report.
        checkpoint: The optional checkpoint you want to report.
        checkpoint_dir_name: Not supported by this implementation. If it is
            passed, a warning is logged and the argument is ignored. It is only
            honored by the new Ray Train implementation
            (``RAY_TRAIN_V2_ENABLED=1``).
    """
    if checkpoint_dir_name is not None:
        logger.warning(
            "`checkpoint_dir_name` is only supported in the new Ray Train "
            "implementation, which can be enabled with `RAY_TRAIN_V2_ENABLED=1`. "
            "This argument will be ignored."
        )

    # If we are running in a Tune function, switch to `ray.tune.report`.
    from ray.tune.trainable.trainable_fn_utils import _in_tune_session

    if _in_tune_session():
        import ray.tune

        if _v2_migration_warnings_enabled():
            _log_deprecation_warning(
                "`ray.train.report` should be switched to "
                "`ray.tune.report` when running in a function "
                "passed to Ray Tune. This will be an error in the future. "
                "See this issue for more context: "
                "https://github.com/ray-project/ray/issues/49454"
            )
        return ray.tune.report(metrics, checkpoint=checkpoint)

    get_session().report(metrics, checkpoint=checkpoint)


@PublicAPI(stability="stable")
@_warn_session_misuse()
def get_checkpoint() -> Optional[Checkpoint]:
    """Return the most recent checkpoint to resume from, if there is one.

    Inside a Ray Tune function this defers to ``ray.tune.get_checkpoint`` and
    may log a deprecation warning.

    Example:

        .. testcode::

            import tempfile

            from ray import train
            from ray.train import Checkpoint
            from ray.train.torch import TorchTrainer


            def train_func(config):
                start_epoch = 0
                checkpoint = train.get_checkpoint()
                if checkpoint:
                    with checkpoint.as_directory() as checkpoint_dir:
                        # Load back training state
                        ...

                for epoch in range(start_epoch, config.get("num_epochs", 10)):
                    # Do training...

                    metrics = {"loss": ...}

                    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                        # Save the checkpoint...

                        checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
                        train.report(metrics, checkpoint=checkpoint)

            trainer = TorchTrainer(
                train_func, scaling_config=train.ScalingConfig(num_workers=2)
            )

    Returns:
        Checkpoint object if the session is currently being resumed.
            Otherwise, return None.
    """
    from ray.tune.trainable.trainable_fn_utils import _in_tune_session

    # Outside of Tune, the checkpoint lives on the Train session.
    if not _in_tune_session():
        return get_session().loaded_checkpoint

    # Inside a Tune function, defer to `ray.tune.get_checkpoint`.
    import ray.tune

    if _v2_migration_warnings_enabled():
        _log_deprecation_warning(
            "`ray.train.get_checkpoint` should be switched to "
            "`ray.tune.get_checkpoint` when running in a function "
            "passed to Ray Tune. This will be an error in the future. "
            "See this issue for more context: "
            "https://github.com/ray-project/ray/issues/49454"
        )
    return ray.tune.get_checkpoint()


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_metadata() -> Dict[str, Any]:
    """Return the user metadata dict that was passed to the Trainer constructor."""
    session = get_session()
    return session.metadata


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_experiment_name() -> str:
    """Return the name of the experiment that the current trial belongs to."""
    session = get_session()
    return session.experiment_name


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_name() -> str:
    """Return the name of the current trial."""
    session = get_session()
    return session.trial_name


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_id() -> str:
    """Return the id of the current trial."""
    session = get_session()
    return session.trial_id


@PublicAPI(stability="alpha")
@_warn_session_misuse()
def get_run_id() -> str:
    """Return the unique Train run id of the current trial."""
    session = get_session()
    return session.run_id


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_resources() -> "PlacementGroupFactory":
    """Return the resources reserved for the current trial."""
    session = get_session()
    return session.trial_resources


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_dir() -> str:
    """Log directory corresponding to the trial directory for a Tune session.
    If calling from a Train session, this will give the trial directory of its parent
    Tune session.

    .. testcode::

        import ray.tune

        def train_func(config):
            print(ray.tune.get_context().get_trial_dir())

        tuner = ray.tune.Tuner(train_func)
        tuner.fit()

    .. testoutput::
        :options: +MOCK

        /Users/root/ray_results/train_func_2023-07-19_15-01-37/train_func_d620c_00000_0_2023-07-19_15-01-40

    Returns:
        The ``trial_dir`` of the active session.
    """
    active_session = get_session()
    return active_session.trial_dir


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=1)
def get_world_size() -> int:
    """Get the current world size (i.e. total number of workers) for this run.

    .. testcode::

        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.tensorflow import TensorflowTrainer

        NUM_WORKERS = 2

        def train_loop_per_worker(config):
            assert train.get_context().get_world_size() == NUM_WORKERS

        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TensorflowTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=NUM_WORKERS),
            datasets={"train": train_dataset}
        )
        trainer.fit()

    .. testoutput::
        :hide:

        ...

    Returns:
        The ``world_size`` reported by the active session.

    Raises:
        RuntimeError: If the active session does not expose ``world_size``.
    """
    session = get_session()
    # Only sessions that carry worker topology (TrainSession) define this attribute.
    if not hasattr(session, "world_size"):
        raise RuntimeError(
            "`get_world_size` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.world_size


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_world_rank() -> int:
    """Get the world rank of this worker.

    .. testcode::

        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.tensorflow import TensorflowTrainer

        def train_loop_per_worker(config):
            if train.get_context().get_world_rank() == 0:
                print("Worker 0")

        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TensorflowTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=2),
            datasets={"train": train_dataset}
        )
        trainer.fit()

    .. testoutput::
        :hide:

        ...

    Returns:
        The ``world_rank`` reported by the active session.

    Raises:
        RuntimeError: If the active session does not expose ``world_rank``.
    """
    session = get_session()
    # Only sessions that carry worker topology (TrainSession) define this attribute.
    if not hasattr(session, "world_rank"):
        raise RuntimeError(
            "`get_world_rank` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.world_rank


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_local_rank() -> int:
    """Get the local rank of this worker (rank of the worker on its node).

    .. testcode::

        import torch

        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.torch import TorchTrainer

        def train_loop_per_worker(config):
            if torch.cuda.is_available():
                torch.cuda.set_device(train.get_context().get_local_rank())
            ...

        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TorchTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
            datasets={"train": train_dataset}
        )
        trainer.fit()

    .. testoutput::
        :hide:

        ...

    Returns:
        The ``local_rank`` reported by the active session.

    Raises:
        RuntimeError: If the active session does not expose ``local_rank``.
    """
    session = get_session()
    # Only sessions that carry worker topology (TrainSession) define this attribute.
    if not hasattr(session, "local_rank"):
        raise RuntimeError(
            "`get_local_rank` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.local_rank


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_local_world_size() -> int:
    """Get the local world size of this node (i.e. number of workers on this node).

    Example:

        .. testcode::

            import ray
            from ray import train
            from ray.train import ScalingConfig
            from ray.train.torch import TorchTrainer

            def train_loop_per_worker():
                print(train.get_context().get_local_world_size())

            train_dataset = ray.data.from_items(
                [{"x": x, "y": x + 1} for x in range(32)])
            trainer = TorchTrainer(train_loop_per_worker,
                scaling_config=ScalingConfig(num_workers=1),
                datasets={"train": train_dataset})
            trainer.fit()

        .. testoutput::
            :hide:

            ...

    Returns:
        The ``local_world_size`` reported by the active session.

    Raises:
        RuntimeError: If the active session does not expose ``local_world_size``.
    """
    session = get_session()
    # Only sessions that carry worker topology (TrainSession) define this attribute.
    if not hasattr(session, "local_world_size"):
        raise RuntimeError(
            "`get_local_world_size` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.local_world_size


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_node_rank() -> int:
    """Get the rank of this node.

    Example:

        .. testcode::

            import ray
            from ray import train
            from ray.train import ScalingConfig
            from ray.train.torch import TorchTrainer

            def train_loop_per_worker():
                print(train.get_context().get_node_rank())

            train_dataset = ray.data.from_items(
                [{"x": x, "y": x + 1} for x in range(32)])
            trainer = TorchTrainer(train_loop_per_worker,
                scaling_config=ScalingConfig(num_workers=1),
                datasets={"train": train_dataset})
            trainer.fit()

        .. testoutput::
            :hide:

            ...

    Returns:
        The ``node_rank`` reported by the active session.

    Raises:
        RuntimeError: If the active session does not expose ``node_rank``.
    """
    session = get_session()
    # Only sessions that carry worker topology (TrainSession) define this attribute.
    if not hasattr(session, "node_rank"):
        raise RuntimeError(
            "`get_node_rank` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.node_rank


@PublicAPI(stability="stable")
@_warn_session_misuse()
def get_dataset_shard(
    dataset_name: Optional[str] = None,
) -> Optional["DataIterator"]:
    """Returns the :class:`ray.data.DataIterator` shard for this worker.

    Call :meth:`~ray.data.DataIterator.iter_torch_batches` or
    :meth:`~ray.data.DataIterator.to_tf` on this shard to convert it to the
    appropriate framework-specific data type.

    .. testcode::

        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.torch import TorchTrainer

        def train_loop_per_worker(config):
            ...
            for epoch in range(2):
                # Trainer will automatically handle sharding.
                data_shard = train.get_dataset_shard("train")
                for batch in data_shard.iter_torch_batches():
                    ...

        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TorchTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=2),
            datasets={"train": train_dataset}
        )
        trainer.fit()

    .. testoutput::
        :hide:

        ...

    Args:
        dataset_name: If a Dictionary of Datasets was passed to ``Trainer``, then
            specifies which dataset shard to return.

    Returns:
        The ``DataIterator`` shard to use for this worker.
        If no dataset is passed into Trainer, then return None.

    Raises:
        RuntimeError: If the active session has no ``get_dataset_shard`` method.
    """
    session = get_session()
    # Only Train sessions know how to hand out per-worker dataset shards.
    if not hasattr(session, "get_dataset_shard"):
        raise RuntimeError(
            "`get_dataset_shard` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.get_dataset_shard(dataset_name)


@DeveloperAPI
@_warn_session_misuse()
def get_storage() -> StorageContext:
    """Return the :class:`~ray.train._internal.storage.StorageContext` of the
    active session, exposing the filesystem and paths that were configured
    through `RunConfig`.

    NOTE: This is a developer API, and the `StorageContext` interface may change
    without notice between minor versions.
    """
    active_session = get_session()
    return active_session.storage


def _in_ray_train_worker() -> bool:
    """Return whether the current process is a Ray Train V1 worker.

    That is the case when a session is active (truthy) and its
    ``world_rank`` has been populated (is not ``None``).
    """
    session = get_session()
    if not session:
        return False
    return session.world_rank is not None
