from enum import Enum

from packaging.version import Version

from ray.rllib.utils.checkpoints import try_import_msgpack
from ray.util.annotations import DeveloperAPI

msgpack = None


@DeveloperAPI
class RLlink(Enum):
    PROTOCOL_VERSION = Version("0.0.1")

    # Requests: Client (external env) -> Server (RLlib).
    # ----
    # Ping command (initial handshake).
    PING = "PING"
    # List of episodes (similar to what an EnvRunner.sample() call would return).
    EPISODES = "EPISODES"
    # Request state (e.g. model weights).
    GET_STATE = "GET_STATE"
    # Request Algorithm config.
    GET_CONFIG = "GET_CONFIG"
    # Send episodes and request the next state update right after that.
    # Clients sending this message should wait for a SET_STATE message as an immediate
    # response. Useful for external samplers that must collect on-policy data.
    EPISODES_AND_GET_STATE = "EPISODES_AND_GET_STATE"

    # Responses: Server (RLlib) -> Client (external env).
    # ----
    # Pong response (initial handshake).
    PONG = "PONG"
    # Set state (e.g. model weights).
    SET_STATE = "SET_STATE"
    # Set Algorithm  config.
    SET_CONFIG = "SET_CONFIG"

    # @OldAPIStack (to be deprecated soon).
    ACTION_SPACE = "ACTION_SPACE"
    OBSERVATION_SPACE = "OBSERVATION_SPACE"
    GET_WORKER_ARGS = "GET_WORKER_ARGS"
    GET_WEIGHTS = "GET_WEIGHTS"
    REPORT_SAMPLES = "REPORT_SAMPLES"
    START_EPISODE = "START_EPISODE"
    GET_ACTION = "GET_ACTION"
    LOG_ACTION = "LOG_ACTION"
    LOG_RETURNS = "LOG_RETURNS"
    END_EPISODE = "END_EPISODE"

    def __str__(self):
        return self.name


@DeveloperAPI
def send_rllink_message(sock_, message: dict):
    """Sends a message to the client with a length header."""
    global msgpack
    if msgpack is None:
        msgpack = try_import_msgpack(error=True)

    body = msgpack.packb(message, use_bin_type=True)  # .encode("utf-8")
    header = str(len(body)).zfill(8).encode("utf-8")
    try:
        sock_.sendall(header + body)
    except Exception as e:
        raise ConnectionError(
            f"Error sending message {message} to server on socket {sock_}! "
            f"Original error was: {e}"
        )


@DeveloperAPI
def get_rllink_message(sock_):
    """Receives a message from the client following the length-header protocol."""
    global msgpack
    if msgpack is None:
        msgpack = try_import_msgpack(error=True)

    try:
        # Read the length header (8 bytes)
        header = _get_num_bytes(sock_, 8)
        msg_length = int(header.decode("utf-8"))
        # Read the message body
        body = _get_num_bytes(sock_, msg_length)
        # Decode JSON.
        message = msgpack.unpackb(body, raw=False)  # .loads(body.decode("utf-8"))
        # Check for proper protocol.
        if "type" not in message:
            raise ConnectionError(
                "Protocol Error! Message from peer does not contain `type` field."
            )
        return RLlink(message.pop("type")), message
    except Exception as e:
        raise ConnectionError(
            f"Error receiving message from peer on socket {sock_}! "
            f"Original error was: {e}"
        )


def _get_num_bytes(sock_, num_bytes):
    """Helper function to receive a specific number of bytes."""
    data = b""
    while len(data) < num_bytes:
        packet = sock_.recv(num_bytes - len(data))
        if not packet:
            raise ConnectionError(f"No data received from socket {sock_}!")
        data += packet
    return data
