"""Implements the client side of the client/server pickling protocol.

These picklers are aware of the server internals and can find the
references held for the client within the server.

More discussion about the client/server pickling protocol can be found in:

  ray/util/client/client_pickler.py

ServerPickler serializes Ray object refs and actor handles held by the server
into the appropriate stubs.
ClientUnpickler loads stubs sent by the client and resolves each one to its
associated handle in the server instance.
"""
import io
from typing import TYPE_CHECKING, Any

import ray
import ray.cloudpickle as cloudpickle
from ray._private.client_mode_hook import disable_client_hook
from ray.util.client.client_pickler import PickleStub
from ray.util.client.server.server_stubs import (
    ClientReferenceActor,
    ClientReferenceFunction,
)

if TYPE_CHECKING:
    from ray.util.client.server.server import RayletServicer

import pickle  # noqa: F401


class ServerPickler(cloudpickle.CloudPickler):
    """CloudPickler that emits Ray handles as ``PickleStub`` persistent IDs.

    When an ``ObjectRef`` or ``ActorHandle`` is encountered while pickling, it
    is recorded in the server's tracking tables for ``client_id`` (if not
    already there) and replaced in the stream by a ``PickleStub`` that the
    client can resolve. Every other object is pickled by cloudpickle as usual.
    """

    def __init__(self, client_id: str, server: "RayletServicer", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.client_id = client_id
        self.server = server

    def persistent_id(self, obj):
        """Return a ``PickleStub`` for Ray refs/handles, else ``None``.

        Returning ``None`` tells the pickler to serialize ``obj`` normally.
        """
        if isinstance(obj, ray.ObjectRef):
            ref_key = obj.binary()
            client_objects = self.server.object_refs[self.client_id]
            # Probably a reference nested inside another reference being
            # handed back; keep it in the server's table so it is held.
            if ref_key not in client_objects:
                client_objects[ref_key] = obj
            stub_type = "Object"
        elif isinstance(obj, ray.actor.ActorHandle):
            ref_key = obj._actor_id.binary()
            # Probably a handle nested inside a reference; hold onto it.
            if ref_key not in self.server.actor_refs:
                self.server.actor_refs[ref_key] = obj
            owned_actors = self.server.actor_owners[self.client_id]
            if ref_key not in owned_actors:
                owned_actors.add(ref_key)
            stub_type = "Actor"
        else:
            return None

        return PickleStub(
            type=stub_type,
            client_id=self.client_id,
            ref_id=ref_key,
            name=None,
            baseline_options=None,
        )


class ClientUnpickler(pickle.Unpickler):
    """Unpickler that resolves client ``PickleStub`` persistent IDs.

    Each stub produced on the client side is mapped back to a server-side
    object. Depending on the stub type, that is:

    * the ``ray`` module;
    * a tracked object ref or actor handle;
    * a registered remote function or actor;
    * a self-reference placeholder;
    * a bound actor method.

    NOTE(review): this runs ``pickle`` over bytes received from a client.
    Unpickling can execute arbitrary code, so only trusted clients should be
    able to reach this path.
    """

    def __init__(self, server, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.server = server

    def persistent_load(self, pid):
        """Map a persistent ID written by the client back to a server object.

        Args:
            pid: Persistent ID from the stream; expected to be a ``PickleStub``.

        Returns:
            The server-side object (or placeholder) the stub refers to.

        Raises:
            pickle.UnpicklingError: If ``pid`` is not a ``PickleStub``.
            NotImplementedError: If ``pid.type`` is not a recognized type.

        Lookups into ``server.object_refs`` / ``server.actor_refs`` use plain
        indexing, so an unknown ``ref_id`` surfaces as whatever those
        containers raise (``KeyError`` for dicts).
        """
        # Explicit check instead of ``assert``: this validates data received
        # from a client, and asserts are stripped when running with ``-O``.
        if not isinstance(pid, PickleStub):
            raise pickle.UnpicklingError(
                f"Unsupported persistent id of type {type(pid).__name__!r}"
            )
        if pid.type == "Ray":
            return ray
        elif pid.type == "Object":
            return self.server.object_refs[pid.client_id][pid.ref_id]
        elif pid.type == "Actor":
            return self.server.actor_refs[pid.ref_id]
        elif pid.type == "RemoteFuncSelfReference":
            return ClientReferenceFunction(pid.client_id, pid.ref_id)
        elif pid.type == "RemoteFunc":
            return self.server.lookup_or_register_func(
                pid.ref_id, pid.client_id, pid.baseline_options
            )
        elif pid.type == "RemoteActorSelfReference":
            return ClientReferenceActor(pid.client_id, pid.ref_id)
        elif pid.type == "RemoteActor":
            return self.server.lookup_or_register_actor(
                pid.ref_id, pid.client_id, pid.baseline_options
            )
        elif pid.type == "RemoteMethod":
            actor = self.server.actor_refs[pid.ref_id]
            return getattr(actor, pid.name)
        else:
            raise NotImplementedError("Uncovered client data type")


def dumps_from_server(
    obj: Any, client_id: str, server_instance: "RayletServicer", protocol=None
) -> bytes:
    """Pickle ``obj`` on the server for delivery to the client ``client_id``.

    Ray object refs and actor handles inside ``obj`` are replaced by stubs and
    recorded on ``server_instance`` (see ``ServerPickler``).

    Args:
        obj: Value to serialize.
        client_id: Identifier of the client the payload is destined for.
        server_instance: Server whose tracking tables receive the handles.
        protocol: Pickle protocol forwarded to the pickler.

    Returns:
        The serialized payload.
    """
    with io.BytesIO() as buffer:
        pickler = ServerPickler(client_id, server_instance, buffer, protocol=protocol)
        pickler.dump(obj)
        payload = buffer.getvalue()
    return payload


def loads_from_client(
    data: bytes,
    server_instance: "RayletServicer",
    *,
    fix_imports=True,
    encoding="ASCII",
    errors="strict",
) -> Any:
    """Unpickle a client payload, resolving its stubs against the server.

    Args:
        data: Pickled bytes produced by the client-side pickler.
        server_instance: Server holding the refs/handles the stubs point to.
        fix_imports: Forwarded to ``pickle.Unpickler``.
        encoding: Forwarded to ``pickle.Unpickler``.
        errors: Forwarded to ``pickle.Unpickler``.

    Returns:
        The deserialized object.

    Raises:
        TypeError: If ``data`` is a ``str`` instead of bytes.
    """
    # NOTE(review): the client hook is disabled while unpickling, presumably
    # so Ray API use triggered during load targets this server's Ray rather
    # than being routed through client mode. TODO confirm.
    with disable_client_hook():
        if isinstance(data, str):
            raise TypeError("Can't load pickle from unicode string")
        with io.BytesIO(data) as buffer:
            # Forward every decoding option the caller supplied, including
            # ``errors``, which was previously accepted but dropped.
            return ClientUnpickler(
                server_instance,
                buffer,
                fix_imports=fix_imports,
                encoding=encoding,
                errors=errors,
            ).load()
