import copy
import logging
import os
import queue
import threading
from typing import Optional

import numpy as np

from ray.air.constants import _ERROR_REPORT_TIMEOUT

logger = logging.getLogger(__name__)


def is_nan(value):
    """Report whether ``value`` is NaN, as determined by ``numpy.isnan``.

    The result is whatever ``numpy.isnan`` returns for the given input
    (a boolean scalar for scalars, an element-wise boolean array for arrays).
    """
    nan_flag = np.isnan(value)
    return nan_flag


def is_nan_or_inf(value):
    """Report whether ``value`` is NaN or (positive or negative) infinity.

    The NaN check runs first; the infinity check (``numpy.isinf``) is only
    evaluated when the value is not NaN.
    """
    nan_check = is_nan(value)
    if nan_check:
        # Hand back the NaN check result itself, exactly as `or` would.
        return nan_check
    return np.isinf(value)


class StartTraceback(Exception):
    """Marker exception whose traceback may be dropped from error output.

    `skip_exceptions` removes these exceptions (and their tracebacks) from an
    exception chain to shorten what is shown to the user.
    """


class StartTracebackWithWorkerRank(StartTraceback):
    """A `StartTraceback` marker that also records a worker rank."""

    def __init__(self, worker_rank: int) -> None:
        """Create the marker and remember ``worker_rank`` on the instance."""
        super().__init__()
        self.worker_rank = worker_rank

    def __reduce__(self):
        """Rebuild the exception from its worker rank when (un)pickling/copying.

        The base initializer is called without arguments, so ``args`` is empty;
        the rank has to be supplied explicitly as the constructor argument.
        """
        ctor_args = (self.worker_rank,)
        return (type(self), ctor_args)


def skip_exceptions(exc: Optional[Exception]) -> Optional[Exception]:
    """Skip all contained `StartTracebacks` to reduce traceback output.

    Returns a shallow copy of the exception with all `StartTracebacks` removed
    from its ``__cause__`` chain.

    If the RAY_AIR_FULL_TRACEBACKS environment variable is set to a non-zero
    integer, the original exception (not a copy) is returned.

    Args:
        exc: The exception to shorten, or ``None``.

    Returns:
        ``None`` if ``exc`` is ``None``. Otherwise an exception:

        * the shortened shallow copy in the normal case;
        * ``exc`` itself if full tracebacks are requested, if ``exc`` is a
          `StartTraceback` without a cause (there is nothing to unwrap to),
          or if ``exc`` cannot be shallow-copied.

    Raises:
        ValueError: If RAY_AIR_FULL_TRACEBACKS is set to a non-integer value.
    """
    should_not_shorten = bool(int(os.environ.get("RAY_AIR_FULL_TRACEBACKS", "0")))

    if should_not_shorten:
        return exc

    if exc is None:
        # Nothing to shorten. Without this guard the copy below would fail
        # with an AttributeError on `None`.
        return None

    if isinstance(exc, StartTraceback):
        # If this is a StartTraceback, skip it and continue with its cause.
        wrapped = exc.__cause__
        if wrapped is None:
            # The marker was raised without `from`: keep the marker itself so
            # that a real exception is never turned into `None`.
            return exc
        return skip_exceptions(wrapped)

    # Perform a shallow copy to prevent recursive __cause__/__context__.
    try:
        new_exc = copy.copy(exc).with_traceback(exc.__traceback__)
    except Exception:
        # Copying replays the constructor with `exc.args`; exceptions with a
        # custom `__init__` signature (or a failing `__reduce__`) cannot be
        # copied. Shortening is best-effort, so report the original error
        # unshortened rather than masking it with a copy failure.
        return exc

    # Make sure nested exceptions are properly skipped.
    cause = getattr(exc, "__cause__", None)
    if cause:
        new_exc.__cause__ = skip_exceptions(cause)

    return new_exc


def exception_cause(exc: Optional[Exception]) -> Optional[Exception]:
    """Return the ``__cause__`` of ``exc``, if any.

    Yields ``None`` when ``exc`` is falsy (e.g. ``None``) or when it carries
    no ``__cause__`` attribute.
    """
    return getattr(exc, "__cause__", None) if exc else None


class RunnerThread(threading.Thread):
    """Supervisor thread that runs your script.

    The thread calls its ``target`` and keeps the return value so that
    `join` can hand it back. Exceptions raised by the target are not
    re-raised in this thread; they are put on ``error_queue`` instead, except
    for ``StopIteration`` and a successful ``SystemExit``, which are treated
    as a request to end the thread without error.
    """

    def __init__(self, *args, error_queue, **kwargs):
        """Create the thread.

        Args:
            *args: Positional arguments forwarded to ``threading.Thread``.
            error_queue: Queue-like object (supporting
                ``put(item, block=..., timeout=...)``) that receives
                exceptions raised by the target.
            **kwargs: Keyword arguments forwarded to ``threading.Thread``.
        """
        super().__init__(*args, **kwargs)
        self._error_queue = error_queue
        # Return value of the target; stays None until the target returns.
        self._ret = None

    @staticmethod
    def _is_graceful_exit(exc: SystemExit) -> bool:
        """Return True if ``exc`` signals a successful exit.

        ``sys.exit()`` without an argument produces ``code=None``, which
        Python defines as exit status zero, so it is treated like ``0``.
        """
        return exc.code is None or exc.code == 0

    def _propagate_exception(self, e: BaseException):
        """Put ``e`` on the error queue, logging if the queue stays full."""
        try:
            # report the error but avoid indefinite blocking which would
            # prevent the exception from being propagated in the unlikely
            # case that something went terribly wrong
            self._error_queue.put(e, block=True, timeout=_ERROR_REPORT_TIMEOUT)
        except queue.Full:
            logger.critical(
                (
                    "Runner Thread was unable to report error to main "
                    "function runner thread. This means a previous error "
                    "was not processed. This should never happen."
                )
            )

    def run(self):
        """Run the target, storing its result or reporting its exception."""
        try:
            self._ret = self._target(*self._args, **self._kwargs)
        except StopIteration:
            logger.debug(
                (
                    "Thread runner raised StopIteration. Interpreting it as a "
                    "signal to terminate the thread without error."
                )
            )
        except SystemExit as e:
            # Do not propagate up for graceful termination.
            if self._is_graceful_exit(e):
                logger.debug(
                    (
                        "Thread runner raised SystemExit with error code 0. "
                        "Interpreting it as a signal to terminate the thread "
                        "without error."
                    )
                )
            else:
                # If non-zero exit code, then raise exception to main thread.
                self._propagate_exception(e)
        except BaseException as e:
            # Propagate all other exceptions to the main thread.
            self._propagate_exception(e)

    def join(self, timeout=None):
        """Wait for the thread and return the target's return value.

        Args:
            timeout: Optional timeout in seconds, forwarded to
                ``threading.Thread.join``.

        Returns:
            The value returned by the target, or ``None`` if the target has
            not returned a value (still running after the timeout, ended
            through an exception, or returned ``None``).
        """
        super().join(timeout)
        return self._ret
