import time
from contextlib import contextmanager
from typing import Any, Dict, Optional

import ray as real_ray
import ray.util.client.server.server as ray_client_server
from ray._private.client_mode_hook import disable_client_hook
from ray.job_config import JobConfig
from ray.util.client import ray


@contextmanager
def ray_start_client_server(metadata=None, ray_connect_handler=None, **kwargs):
    """Start a client server/client pair and yield only the client-side API.

    Thin wrapper around ``ray_start_client_server_pair``: every argument is
    forwarded unchanged, and the server handle from the yielded
    ``(client, server)`` pair is discarded.
    """
    with ray_start_client_server_pair(
        metadata=metadata, ray_connect_handler=ray_connect_handler, **kwargs
    ) as (client, _server):
        yield client


@contextmanager
def ray_start_client_server_for_address(address):
    """
    Starts a Ray client server that initializes drivers at the specified address.

    The server is given a connect handler that calls ``ray.init(address, ...)``
    on the real (non-client) Ray package. It does this only when Ray is not
    already initialized. Yields the client-mode ``ray`` API produced by
    ``ray_start_client_server``.

    Args:
        address: Address passed positionally to ``ray.init`` by the handler.
    """

    def connect_handler(
        job_config: Optional[JobConfig] = None, **ray_init_kwargs: Any
    ):
        # Import the real package under a distinct name: the module-level
        # ``ray`` in this file is the client-mode stub, not the driver API.
        import ray as driver_ray

        with disable_client_hook():
            if not driver_ray.is_initialized():
                return driver_ray.init(
                    address, job_config=job_config, **ray_init_kwargs
                )
        # Ray is already initialized in this process; leave it as is.
        return None

    with ray_start_client_server(ray_connect_handler=connect_handler) as client:
        yield client


@contextmanager
def ray_start_client_server_pair(metadata=None, ray_connect_handler=None, **kwargs):
    """Start a Ray client server on 127.0.0.1:50051 and connect a client to it.

    Yields a ``(ray, server)`` tuple. ``ray`` is the client-mode API imported
    from ``ray.util.client``. ``server`` is the object returned by
    ``ray_client_server.serve``.

    On exit, this function:
      * disconnects the client;
      * stops the server;
      * polls once per second until Ray reports it is no longer initialized;
      * sleeps 3 s more.

    Args:
        metadata: Forwarded to ``ray.connect``.
        ray_connect_handler: Forwarded to ``ray_client_server.serve``.
        **kwargs: Extra keyword arguments forwarded to ``ray.connect``.

    Raises:
        AssertionError: If Ray is already initialized on entry.
        RuntimeError: If Ray is still initialized roughly 30 s after teardown
            began.
    """
    ray._inside_client_test = True
    server = None
    try:
        with disable_client_hook():
            assert not ray.is_initialized()
        server = ray_client_server.serve(
            "127.0.0.1", 50051, ray_connect_handler=ray_connect_handler
        )
        ray.connect("127.0.0.1:50051", metadata=metadata, **kwargs)
    except BaseException:
        # Setup failed part-way through. Reset the test flag so it does not
        # leak into later tests, and stop the server if it was started so the
        # fixed port is not left occupied.
        ray._inside_client_test = False
        if server is not None:
            server.stop(0)
        raise
    try:
        yield ray, server
    finally:
        ray._inside_client_test = False
        try:
            ray.disconnect()
        finally:
            # Stop the server even if disconnect() raises.
            server.stop(0)
        del server
        start = time.monotonic()
        with disable_client_hook():
            while ray.is_initialized():
                time.sleep(1)
                if time.monotonic() - start > 30:
                    raise RuntimeError("Failed to terminate Ray")
        # Allow windows to close processes before moving on
        time.sleep(3)


@contextmanager
def ray_start_cluster_client_server_pair(address):
    """Start a client server whose driver joins ``address``, then connect to it.

    Yields a ``(ray, server)`` tuple. ``ray`` is the client-mode API imported
    from ``ray.util.client``. ``server`` is the object returned by
    ``ray_client_server.serve`` on 127.0.0.1:50051.

    On exit, the client is disconnected and the server is stopped. Unlike
    ``ray_start_client_server_pair``, this does not wait for Ray to report
    that it has shut down.

    Args:
        address: Address passed to ``real_ray.init(address=...)`` by the
            server's connect handler.
    """
    ray._inside_client_test = True

    def ray_connect_handler(job_config=None, **ray_init_kwargs):
        # NOTE(review): job_config and ray_init_kwargs are accepted but
        # ignored; the driver always connects to ``address`` with default
        # init settings.
        real_ray.init(address=address)

    server = None
    try:
        server = ray_client_server.serve(
            "127.0.0.1", 50051, ray_connect_handler=ray_connect_handler
        )
        ray.connect("127.0.0.1:50051")
    except BaseException:
        # Setup failed: reset the test flag and stop the server if it was
        # started, so the fixed port is not left occupied.
        ray._inside_client_test = False
        if server is not None:
            server.stop(0)
        raise
    try:
        yield ray, server
    finally:
        ray._inside_client_test = False
        try:
            ray.disconnect()
        finally:
            # Stop the server even if disconnect() raises.
            server.stop(0)
