import importlib
import threading
from contextlib import nullcontext
from typing import TYPE_CHECKING, ContextManager, List, Optional, Type

import ray
from ray._private.accelerators import get_accelerator_manager_for_resource
from ray.experimental.channel.communicator import Communicator

if TYPE_CHECKING:
    import torch

# The accelerator context singleton on this process.
_accelerator_context_lock = threading.Lock()
_default_accelerator_context: Optional["AcceleratorContext"] = None
_global_custom_context: Optional["AcceleratorContext"] = None


class AcceleratorContext:
    """
    Provides a unified interface for managing different accelerator backends.
    This includes stream management, event creation, device context control,
    and communicator support for distributed communication.
    """

    def __init__(self, torch_module_name: str, communicator_cls: Type[Communicator]):
        """
        Initializes an accelerator context with the specified torch device module
        and communicator class.

        Args:
            torch_module_name: Name of the torch device module (e.g., "cuda", "cpu").
            communicator_cls: Class used to handle communication.
        """

        # The name of the torch module (e.g., 'cuda', 'npu')
        self._torch_module_name: str = torch_module_name
        # The Communicator class used to manage communication
        self._communicator_cls: Type[Communicator] = communicator_cls

        # The torch backend module (e.g., torch.cuda). "cpu" has no torch
        # backend module, so this stays None; methods that require a backend
        # (current_stream, create_event) must not be called for "cpu".
        self._torch_mod = None
        if torch_module_name != "cpu":
            self._torch_mod = importlib.import_module(f"torch.{torch_module_name}")

    @staticmethod
    def get() -> "AcceleratorContext":
        """
        Returns the singleton instance of the accelerator context.

        If a custom accelerator has been registered, initializes the context
        based on the registration. Otherwise, selects an appropriate runtime
        based on the available device (CUDA or CPU) and registers the
        corresponding default communicator.

        Returns:
            AcceleratorContext: A singleton instance of the appropriate
            runtime context.
        """

        global _default_accelerator_context, _global_custom_context

        with _accelerator_context_lock:
            # A registered custom context always takes precedence over the
            # lazily-created default.
            if _global_custom_context is not None:
                return _global_custom_context

            if _default_accelerator_context is None:
                if len(ray.get_gpu_ids()) > 0:
                    # GPUs were assigned to this worker: use CUDA with NCCL.
                    from ray.experimental.channel.nccl_group import _NcclGroup

                    _default_accelerator_context = AcceleratorContext(
                        "cuda", _NcclGroup
                    )
                else:
                    # No GPUs assigned: fall back to the CPU communicator.
                    from ray.experimental.channel.cpu_communicator import (
                        CPUCommunicator,
                    )

                    _default_accelerator_context = AcceleratorContext(
                        "cpu", CPUCommunicator
                    )

            return _default_accelerator_context

    @staticmethod
    def set(accelerator_context: "AcceleratorContext") -> None:
        """
        Overwrites the default accelerator context.

        Args:
            accelerator_context: The context to register.
        """
        global _global_custom_context

        # Take the same lock as get() so a concurrent get() never races with
        # the publication of a custom context.
        with _accelerator_context_lock:
            _global_custom_context = accelerator_context

    def get_accelerator_devices(self) -> List["torch.device"]:
        """
        Gets the torch device list configured for this process.

        Returns:
            List[torch.device]: The torch device list.
        """
        import torch

        # CPU has exactly one logical device and no visibility mapping.
        if self._torch_module_name == "cpu":
            return [torch.device("cpu")]

        if self._torch_module_name == "cuda":
            accelerator_ids = [str(id) for id in ray.get_gpu_ids()]
            accelerator_manager = get_accelerator_manager_for_resource("GPU")
        else:
            # Non-CUDA accelerators are looked up by their upper-cased
            # resource name (e.g. "NPU") in the runtime context.
            accelerator_ids = [
                str(id)
                for id in ray.get_runtime_context().get_accelerator_ids()[
                    self._torch_module_name.upper()
                ]
            ]
            accelerator_manager = get_accelerator_manager_for_resource(
                self._torch_module_name.upper()
            )

        device_ids = []

        if len(accelerator_ids) > 0:
            accelerator_visible_list = (
                accelerator_manager.get_current_process_visible_accelerator_ids()
            )
            if accelerator_visible_list is None:
                accelerator_visible_list = []

            # If there are multiple Accelerators, return a list of devices.
            # If using fractional Accelerators, these IDs are not guaranteed
            # to be unique across different processes.
            for accelerator_id in accelerator_ids:
                try:
                    # Torch device indices are positions within the visible
                    # device list, not the raw accelerator IDs.
                    device_ids.append(accelerator_visible_list.index(accelerator_id))
                except ValueError:
                    # Suppress the ValueError chain: the RuntimeError message
                    # already explains the misconfiguration.
                    raise RuntimeError(
                        f"{accelerator_manager.get_visible_accelerator_ids_env_var()} set incorrectly. "
                        f"expected to include {accelerator_id}. "
                        "Did you override this environment"
                        " variable? If not, please help file an issue on Github."
                    ) from None

        else:
            # If called on the driver or outside of Ray Train, return the
            # 0th device.
            device_ids.append(0)

        return [
            torch.device(f"{self._torch_module_name}:{device_id}")
            for device_id in device_ids
        ]

    def get_device_context(self, device: "torch.device") -> ContextManager:
        """
        Retrieves the context manager for the specified accelerator device.
        There is no device context for CPU, returning a nullcontext.

        Args:
            device: The target device for which the context manager is required.

        Returns:
            ContextManager: A context manager specific to the device type.
        """
        if device.type == "cpu":
            return nullcontext()

        return self._torch_mod.device(device)

    def current_stream(self):
        """
        Retrieves the current execution stream for the accelerator device.

        Must not be called when the backend is "cpu" (no torch backend module).
        """
        return self._torch_mod.current_stream()

    def create_event(self):
        """
        Creates an event object for the accelerator device.

        Must not be called when the backend is "cpu" (no torch backend module).
        """
        return self._torch_mod.Event()

    def generate_communicator_id(self) -> str:
        """
        Generates a communication identifier for communication group.
        """
        return self._communicator_cls.generate_communicator_id()

    def create_communicator(self, *args, **kwargs) -> Communicator:
        """
        Creates a communication group for collective operations.

        All arguments are forwarded to the registered communicator class.
        """
        return self._communicator_cls(*args, **kwargs)

    @property
    def module_name(self) -> str:
        """
        Gets the name of the torch module backing the accelerator.
        """
        return self._torch_module_name

    @property
    def communicator_cls(self) -> Type[Communicator]:
        """
        Returns the communicator class.
        """
        return self._communicator_cls

    @property
    def accelerator_count(self) -> int:
        """
        Returns the number of accelerators assigned by ray.
        """
        if self._torch_module_name == "cuda":
            return len(ray.get_gpu_ids())
        else:
            accelerator_ids = ray.get_runtime_context().get_accelerator_ids()
            return len(accelerator_ids.get(self._torch_module_name.upper(), []))


def register_accelerator_context(
    torch_module_name: str, communicator_cls: Type[Communicator]
):
    """
    Registers a custom accelerator context built from the given device type
    and communicator class.

    Args:
        torch_module_name: The name of the device module under torch.
        communicator_cls: The communicator class associated with the device.
    """
    # Build the context and install it as the process-wide custom context.
    AcceleratorContext.set(AcceleratorContext(torch_module_name, communicator_cls))


def is_accelerator_context_registered() -> bool:
    """
    Checks whether a custom accelerator context has been registered.

    Returns:
        bool: True if a custom accelerator context is registered
              (_global_custom_context is not None), False otherwise.
    """
    # Directly return the comparison instead of the redundant
    # if/return-True/return-False pattern.
    return _global_custom_context is not None
