import logging
import numbers
from typing import Any, Callable, Dict, List, Optional, Tuple

from ray._private import ray_constants
from ray._private.dict import flatten_dict
from ray.air._internal.util import is_nan
from ray.air.config import MAX
from ray.train import Checkpoint, CheckpointConfig
from ray.train._internal.session import _TrainingResult
from ray.train._internal.storage import _delete_fs_path
from ray.train.constants import TUNE_ONLY_STORE_CHECKPOINT_SCORE_ATTRIBUTE

logger = logging.getLogger(__name__)


def _insert_into_sorted_list(
    list: List[_TrainingResult],
    item: _TrainingResult,
    key: Callable[[_TrainingResult], Any],
    checkpoint_to_report_index: Optional[Dict[Checkpoint, int]] = None,
):
    """Insert an item into a sorted list with a custom key function.

    Args:
        list: The list to insert the item into.
        item: The item to insert.
        key: The key function to use to sort the list.
        checkpoint_to_report_index: A dictionary mapping checkpoints to report indices.
            Used to break ties when scores are equal.
    """
    checkpoint_to_report_index = checkpoint_to_report_index or {}
    # TODO: optimize this with sortedlist, batching, etc
    i = 0
    while i < len(list):
        # When scores are equal, later checkpoints are later in the list.
        list_item_key, item_key = key(list[i]), key(item)
        if list_item_key > item_key or (
            list_item_key == item_key
            and checkpoint_to_report_index.get(list[i].checkpoint, 0)
            > checkpoint_to_report_index.get(item.checkpoint, 0)
        ):
            break
        i += 1
    list.insert(i, item)


class _CheckpointManager:
    """Checkpoint manager that handles checkpoint book-keeping for a trial.

    The main purpose of this abstraction is to keep the top K checkpoints based on
    recency/a user-provided metric.

    NOTE: This class interacts with `_TrainingResult` objects, which are
    (checkpoint, metrics) pairs. This is to order checkpoints by metrics.

    Args:
        checkpoint_config: Defines how many and which checkpoints to keep.

    Raises:
        ValueError: If `num_to_keep` in the config is set and is not positive.
    """

    def __init__(self, checkpoint_config: Optional["CheckpointConfig"]):
        self._checkpoint_config = checkpoint_config or CheckpointConfig()

        # List of checkpoints ordered by ascending score.
        self._checkpoint_results: List["_TrainingResult"] = []

        # The latest registered checkpoint.
        # This should never be immediately deleted upon registration,
        # even if it's not in the top K checkpoints, based on score.
        self._latest_checkpoint_result: Optional["_TrainingResult"] = None

        if (
            self._checkpoint_config.num_to_keep is not None
            and self._checkpoint_config.num_to_keep <= 0
        ):
            raise ValueError(
                f"`num_to_keep` must >= 1, got: "
                f"{self._checkpoint_config.num_to_keep}"
            )

    @property
    def checkpoint_config(self) -> "CheckpointConfig":
        """The checkpoint configuration this manager enforces."""
        return self._checkpoint_config

    def register_checkpoint(self, checkpoint_result: "_TrainingResult"):
        """Register new checkpoint and add to bookkeeping.

        This method will register a new checkpoint and add it to the internal
        bookkeeping logic. This means the checkpoint manager will decide if
        this checkpoint should be kept, and if older or worse performing
        checkpoints should be deleted.

        Args:
            checkpoint_result: Tracked checkpoint object to add to bookkeeping.
        """
        score_attr = self._checkpoint_config.checkpoint_score_attribute
        if ray_constants.env_bool(TUNE_ONLY_STORE_CHECKPOINT_SCORE_ATTRIBUTE, False):
            # Drop all metrics except the score attribute to reduce memory usage.
            metrics = (
                {score_attr: checkpoint_result.metrics[score_attr]}
                if score_attr in checkpoint_result.metrics
                else {}
            )
            checkpoint_result = _TrainingResult(
                checkpoint=checkpoint_result.checkpoint,
                metrics=metrics,
            )

        # BUG FIX: assign *after* the optional metric-stripping above, so that
        # the deletion guard below compares against the same object that gets
        # stored in `self._checkpoint_results`. Previously, when
        # TUNE_ONLY_STORE_CHECKPOINT_SCORE_ATTRIBUTE was set, the latest
        # (stripped) result was a different object from
        # `self._latest_checkpoint_result` and could be deleted.
        self._latest_checkpoint_result = checkpoint_result

        if score_attr is not None and score_attr in checkpoint_result.metrics:
            # If we're ordering by a score, insert the checkpoint
            # so that the list remains sorted.
            _insert_into_sorted_list(
                self._checkpoint_results,
                checkpoint_result,
                key=self._get_checkpoint_score,
            )
        else:
            # If no metric is provided, just append (ordering by time of registration).
            self._checkpoint_results.append(checkpoint_result)

        if self._checkpoint_config.num_to_keep is not None:
            # Delete the bottom (N - K) checkpoints
            worst_results = set(
                self._checkpoint_results[: -self._checkpoint_config.num_to_keep]
            )
            # Except for the latest checkpoint.
            results_to_delete = worst_results - {self._latest_checkpoint_result}

            # Update internal state before actually deleting them.
            self._checkpoint_results = [
                checkpoint_result
                for checkpoint_result in self._checkpoint_results
                if checkpoint_result not in results_to_delete
            ]

            for checkpoint_result in results_to_delete:
                checkpoint = checkpoint_result.checkpoint
                # BUG FIX: use lazy %-style formatting. The previous call
                # passed `checkpoint` as a format argument to a message with
                # no placeholder, which breaks log-record formatting.
                logger.debug("Deleting checkpoint: %s", checkpoint)
                _delete_fs_path(fs=checkpoint.filesystem, fs_path=checkpoint.path)

    def _get_checkpoint_score(
        self, checkpoint: "_TrainingResult"
    ) -> Tuple[bool, numbers.Number]:
        """Get the score for a checkpoint, according to checkpoint config.

        If `mode="min"`, the metric is negated so that the lowest score is
        treated as the best.

        Returns:
            Tuple: A tuple of (not_is_nan: bool, score: numbers.Number).
                This score orders: nan values < float("-inf") < valid numeric metrics
        """
        checkpoint_score_attribute = self._checkpoint_config.checkpoint_score_attribute
        if checkpoint_score_attribute:
            # Flatten nested metric dicts so that nested keys can be looked up
            # with their flattened names.
            flat_metrics = flatten_dict(checkpoint.metrics)
            try:
                checkpoint_result = flat_metrics[checkpoint_score_attribute]
            except KeyError:
                valid_keys = list(flat_metrics.keys())
                logger.error(
                    f"Result dict has no key: {checkpoint_score_attribute}. "
                    f"checkpoint_score_attr must be set to a key in the "
                    f"result dict. Valid keys are: {valid_keys}"
                )
                # Missing metric sorts as worst (but still above NaN).
                checkpoint_result = float("-inf")
        else:
            checkpoint_result = float("-inf")

        checkpoint_score_order = self._checkpoint_config.checkpoint_score_order
        # Negate the score in "min" mode so the ascending sort order is always
        # worst -> best.
        order_factor = 1.0 if checkpoint_score_order == MAX else -1.0

        checkpoint_score = order_factor * checkpoint_result

        if not isinstance(checkpoint_score, numbers.Number):
            raise ValueError(
                f"Unable to persist checkpoint for "
                f"checkpoint_score_attribute: "
                f"{checkpoint_score_attribute} with value "
                f"{checkpoint_score}. "
                f"This attribute must be numerical."
            )

        # NaN sorts below everything else (including -inf): the leading bool
        # in the tuple dominates the comparison.
        if is_nan(checkpoint_score):
            return (False, float("-inf"))
        return (True, checkpoint_score)

    @property
    def best_checkpoint_result(self) -> Optional["_TrainingResult"]:
        """The best-scoring checkpoint (last element of the ascending list)."""
        return self._checkpoint_results[-1] if self._checkpoint_results else None

    @property
    def latest_checkpoint_result(self) -> Optional["_TrainingResult"]:
        """The most recently registered checkpoint, regardless of score."""
        return self._latest_checkpoint_result

    @property
    def best_checkpoint_results(self) -> List["_TrainingResult"]:
        """The kept checkpoints in ascending score order.

        Returns all tracked checkpoints if `num_to_keep` is None.
        """
        if self._checkpoint_config.num_to_keep is None:
            return self._checkpoint_results
        return self._checkpoint_results[-self._checkpoint_config.num_to_keep :]