import json
import logging
import os
import random
import shutil
import subprocess
import sys
import tempfile
import threading
import time
from typing import Any, Dict, Optional

import yaml

import ray
from ray._common.network_utils import build_address
from ray._private.dict import deep_update
from ray.autoscaler._private.fake_multi_node.node_provider import (
    FAKE_DOCKER_DEFAULT_CLIENT_PORT,
    FAKE_DOCKER_DEFAULT_GCS_PORT,
)
from ray.util.queue import Empty, Queue

logger = logging.getLogger(__name__)

DEFAULT_DOCKER_IMAGE = "rayproject/ray:nightly-py{major}{minor}-cpu"


class ResourcesNotReadyError(RuntimeError):
    pass


class DockerCluster:
    """Docker cluster wrapper.

    Creates a directory for starting a fake multinode docker cluster.

    Includes APIs to update the cluster config as needed in tests,
    and to start and connect to the cluster.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        self._base_config_file = os.path.join(
            os.path.dirname(__file__), "example_docker.yaml"
        )
        self._tempdir = None
        self._config_file = None
        self._nodes_file = None
        self._nodes = {}
        self._status_file = None
        self._status = {}
        self._partial_config = config
        self._cluster_config = None
        self._docker_image = None

        self._monitor_script = os.path.join(
            os.path.dirname(__file__), "docker_monitor.py"
        )
        self._monitor_process = None

        self._execution_thread = None
        self._execution_event = threading.Event()
        self._execution_queue = None

    @property
    def config_file(self):
        return self._config_file

    @property
    def cluster_config(self):
        return self._cluster_config

    @property
    def cluster_dir(self):
        return self._tempdir

    @property
    def gcs_port(self):
        return self._cluster_config.get("provider", {}).get(
            "host_gcs_port", FAKE_DOCKER_DEFAULT_GCS_PORT
        )

    @property
    def client_port(self):
        return self._cluster_config.get("provider", {}).get(
            "host_client_port", FAKE_DOCKER_DEFAULT_CLIENT_PORT
        )

    def connect(self, client: bool = True, timeout: int = 120, **init_kwargs):
        """Connect to the docker-compose Ray cluster.

        Assumes the cluster is at RAY_TESTHOST (defaults to
        ``127.0.0.1``).

        Args:
            client: If True, uses Ray client to connect to the
                cluster. If False, uses GCS to connect to the cluster.
            timeout: Connection timeout in seconds.
            **init_kwargs: kwargs to pass to ``ray.init()``.

        """
        host = os.environ.get("RAY_TESTHOST", "127.0.0.1")

        if client:
            port = self.client_port
            address = f"ray://{build_address(host, port)}"
        else:
            port = self.gcs_port
            address = build_address(host, port)

        timeout_at = time.monotonic() + timeout
        while time.monotonic() < timeout_at:
            try:
                ray.init(address, **init_kwargs)
                self.wait_for_resources({"CPU": 1})
            except ResourcesNotReadyError:
                time.sleep(1)
                continue
            else:
                break

        try:
            ray.cluster_resources()
        except Exception as e:
            raise RuntimeError(f"Timed out connecting to Ray: {e}")

    def remote_execution_api(self) -> "RemoteAPI":
        """Create an object to control cluster state from within the cluster."""
        self._execution_queue = Queue(actor_options={"num_cpus": 0})
        stop_event = self._execution_event

        def entrypoint():
            while not stop_event.is_set():
                try:
                    cmd, kwargs = self._execution_queue.get(timeout=1)
                except Empty:
                    continue

                if cmd == "kill_node":
                    self.kill_node(**kwargs)

        self._execution_thread = threading.Thread(target=entrypoint)
        self._execution_thread.start()

        return RemoteAPI(self._execution_queue)

    @staticmethod
    def wait_for_resources(resources: Dict[str, float], timeout: int = 60):
        """Wait until Ray cluster resources are available

        Args:
            resources: Minimum resources needed before
                this function returns.
            timeout: Timeout in seconds.

        """
        timeout = time.monotonic() + timeout

        available = ray.cluster_resources()
        while any(available.get(k, 0.0) < v for k, v in resources.items()):
            if time.monotonic() > timeout:
                raise ResourcesNotReadyError(
                    f"Timed out waiting for resources: {resources}"
                )
            time.sleep(1)
            available = ray.cluster_resources()

    def update_config(self, config: Optional[Dict[str, Any]] = None):
        """Update autoscaling config.

        Does a deep update of the base config with a new configuration.
        This can change autoscaling behavior.

        Args:
            config: Partial config to update current
                config with.

        """
        assert self._tempdir, "Call setup() first"

        config = config or {}

        if config:
            self._partial_config = config

        if not config.get("provider", {}).get("image"):
            # No image specified, trying to parse from buildkite
            docker_image = os.environ.get("RAY_DOCKER_IMAGE", None)

            if not docker_image:
                # If still no docker image, use one according to Python version
                mj = sys.version_info.major
                mi = sys.version_info.minor

                docker_image = DEFAULT_DOCKER_IMAGE.format(major=mj, minor=mi)

            self._docker_image = docker_image

        with open(self._base_config_file, "rt") as f:
            cluster_config = yaml.safe_load(f)

        if self._partial_config:
            deep_update(cluster_config, self._partial_config, new_keys_allowed=True)

        if self._docker_image:
            cluster_config["provider"]["image"] = self._docker_image

        cluster_config["provider"]["shared_volume_dir"] = self._tempdir

        self._cluster_config = cluster_config

        with open(self._config_file, "wt") as f:
            yaml.safe_dump(self._cluster_config, f)

        logging.info(f"Updated cluster config to: {self._cluster_config}")

    def maybe_pull_image(self):
        if self._docker_image:
            try:
                images_str = subprocess.check_output(
                    f"docker image inspect {self._docker_image}", shell=True
                )
                images = json.loads(images_str)
            except Exception as e:
                logger.error(f"Error inspecting image {self._docker_image}: {e}")
                return

            if not images:
                try:
                    subprocess.check_call(
                        f"docker pull {self._docker_image}", shell=True
                    )
                except Exception as e:
                    logger.error(f"Error pulling image {self._docker_image}: {e}")

    def setup(self):
        """Setup docker compose cluster environment.

        Creates the temporary directory, writes the initial config file,
        and pulls the docker image, if required.
        """
        self._tempdir = tempfile.mkdtemp(dir=os.environ.get("RAY_TEMPDIR", None))
        os.chmod(self._tempdir, 0o777)
        self._config_file = os.path.join(self._tempdir, "cluster.yaml")
        self._nodes_file = os.path.join(self._tempdir, "nodes.json")
        self._status_file = os.path.join(self._tempdir, "status.json")
        self.update_config()
        self.maybe_pull_image()

    def teardown(self, keep_dir: bool = False):
        """Tear down docker compose cluster environment.

        Args:
            keep_dir: If True, cluster directory
                will not be removed after termination.
        """
        if not keep_dir:
            shutil.rmtree(self._tempdir)
        self._tempdir = None
        self._config_file = None

    def _start_monitor(self):
        self._monitor_process = subprocess.Popen(
            [sys.executable, self._monitor_script, self.config_file]
        )
        time.sleep(2)

    def _stop_monitor(self):
        if self._monitor_process:
            self._monitor_process.wait(timeout=30)
            if self._monitor_process.poll() is None:
                self._monitor_process.terminate()
        self._monitor_process = None

    def start(self):
        """Start docker compose cluster.

        Starts the monitor process and runs ``ray up``.
        """
        self._start_monitor()

        subprocess.check_call(
            f"RAY_FAKE_CLUSTER=1 ray up -y {self.config_file}", shell=True
        )

    def stop(self):
        """Stop docker compose cluster.

        Runs ``ray down`` and stops the monitor process.
        """
        if ray.is_initialized:
            ray.shutdown()

        subprocess.check_call(
            f"RAY_FAKE_CLUSTER=1 ray down -y {self.config_file}", shell=True
        )

        self._stop_monitor()
        self._execution_event.set()

    def _update_nodes(self):
        with open(self._nodes_file, "rt") as f:
            self._nodes = json.load(f)

    def _update_status(self):
        with open(self._status_file, "rt") as f:
            self._status = json.load(f)

    def _get_node(
        self,
        node_id: Optional[str] = None,
        num: Optional[int] = None,
        rand: Optional[str] = None,
    ) -> str:
        self._update_nodes()
        if node_id:
            assert (
                not num and not rand
            ), "Only provide either `node_id`, `num`, or `random`."
        elif num:
            assert (
                not node_id and not rand
            ), "Only provide either `node_id`, `num`, or `random`."
            base = "fffffffffffffffffffffffffffffffffffffffffffffffffff"
            node_id = base + str(num).zfill(5)
        elif rand:
            assert (
                not node_id and not num
            ), "Only provide either `node_id`, `num`, or `random`."
            assert rand in [
                "worker",
                "any",
            ], "`random` must be one of ['worker', 'any']"
            choices = list(self._nodes.keys())
            if rand == "worker":
                choices.remove(
                    "fffffffffffffffffffffffffffffffffffffffffffffffffff00000"
                )
            # Else: any
            node_id = random.choice(choices)

        assert node_id in self._nodes, f"Node with ID {node_id} is not in active nodes."
        return node_id

    def _get_docker_container(self, node_id: str) -> Optional[str]:
        self._update_status()
        node_status = self._status.get(node_id)
        if not node_status:
            return None

        return node_status["Name"]

    def kill_node(
        self,
        node_id: Optional[str] = None,
        num: Optional[int] = None,
        rand: Optional[str] = None,
    ):
        """Kill node.

        If ``node_id`` is given, kill that node.

        If ``num`` is given, construct node_id from this number, and kill
        that node.

        If ``rand`` is given (as either ``worker`` or ``any``), kill a random
        node.
        """
        node_id = self._get_node(node_id=node_id, num=num, rand=rand)
        container = self._get_docker_container(node_id=node_id)
        subprocess.check_call(f"docker kill {container}", shell=True)


class RemoteAPI:
    """Remote API to control cluster state from within cluster tasks.

    This API uses a Ray queue to interact with an execution thread on the
    host machine that will execute commands passed to the queue.

    Instances of this class can be serialized and passed to Ray remote actors
    to interact with cluster state (but they can also be used outside actors).

    The API subset is limited to specific commands.

    Args:
        queue: Ray queue to push command instructions to.

    """

    def __init__(self, queue: Queue):
        self._queue = queue

    def kill_node(
        self,
        node_id: Optional[str] = None,
        num: Optional[int] = None,
        rand: Optional[str] = None,
    ):
        self._queue.put(("kill_node", dict(node_id=node_id, num=num, rand=rand)))
