# @OldAPIStack
from collections import OrderedDict
from typing import Dict, List, Tuple, Union

import gymnasium as gym

from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType

# `dnc` is a third-party package that may not be installed: print an install
# hint, then re-raise so that importing this module still fails loudly.
try:
    from dnc import DNC
except ModuleNotFoundError:
    print("dnc module not found. Did you forget to 'pip install dnc'?")
    raise

# Bind the torch / nn names through RLlib's import helper.
torch, nn = try_import_torch()


class DNCMemory(TorchModelV2, nn.Module):
    """Differentiable Neural Computer (DNC) model for RLlib.

    Wraps ixaxaar's DNC implementation, see
    https://github.com/ixaxaar/pytorch-dnc

    Observations go through a linear layer plus the configured
    ``preprocessor`` and are then fed to the DNC, whose output drives a
    logit head and a value head. The DNC itself is built lazily on the first
    ``forward()`` call. The recurrent state is exchanged as a flat list of
    nine tensors: two controller hidden tensors, the read vectors, and one
    tensor per ``MEMORY_KEYS`` entry.
    """

    # Defaults for the model hyper-parameters; any entry can be overridden
    # through the keyword arguments given to the constructor.
    DEFAULT_CONFIG = {
        # Callable used to build the DNC; invoked with keyword arguments in
        # build_dnc()
        "dnc_model": DNC,
        # Passed to the DNC as `num_hidden_layers`; also the leading
        # dimension of the initial controller state tensors
        "num_hidden_layers": 1,
        # Passed to the DNC as `hidden_size`; also the trailing dimension of
        # the initial controller state and the input width of both heads
        "hidden_size": 64,
        # Passed to the DNC as `num_layers`; only the value 1 is supported
        # (asserted in __init__)
        "num_layers": 1,
        # Number of read heads, i.e. how many addrs are read at once
        "read_heads": 4,
        # Number of memory cells (rows of the initial memory matrix)
        "nr_cells": 32,
        # Size of each cell
        "cell_size": 16,
        # Passed to the DNC as `nonlinearity`; presumably the activation of
        # the controller -- see the pytorch-dnc docs to confirm
        "nonlinearity": "tanh",
        # Observation goes through this torch.nn.Module before
        # feeding to the DNC (after a linear layer that maps the flattened
        # observation to `preprocessor_input_size` features).
        # NOTE: this default module is instantiated once, when the class
        # body is executed, not once per model.
        "preprocessor": torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Tanh()),
        # Input size to the preprocessor
        "preprocessor_input_size": 64,
        # The output size of the preprocessor
        # and the input size of the dnc
        "preprocessor_output_size": 64,
    }

    # Names of the DNC memory-state entries, in the order in which their
    # tensors appear in the flat state list (after the two controller
    # tensors and the read vectors).
    MEMORY_KEYS = [
        "memory",
        "link_matrix",
        "precedence",
        "read_weights",
        "write_weights",
        "usage_vector",
    ]

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        **custom_model_kwargs,
    ):
        """Set up the preprocessor and the output heads; the DNC is built later.

        Args:
            obs_space: Observation space; its flattened size is the input
                width of the preprocessor.
            action_space: Action space; only its flattened size is recorded.
            num_outputs: Width of the logits produced by ``forward()``.
            model_config: RLlib model config, forwarded to ``TorchModelV2``.
            name: Model name, forwarded to ``TorchModelV2``.
            **custom_model_kwargs: Overrides for ``DEFAULT_CONFIG`` entries.

        Raises:
            AssertionError: If the resulting ``num_layers`` is not 1.
        """
        import copy

        nn.Module.__init__(self)
        super(DNCMemory, self).__init__(
            obs_space, action_space, num_outputs, model_config, name
        )
        self.num_outputs = num_outputs
        self.obs_dim = gym.spaces.utils.flatdim(obs_space)
        self.act_dim = gym.spaces.utils.flatdim(action_space)

        # Shallow merge: caller-supplied entries win over the class defaults.
        self.cfg = dict(self.DEFAULT_CONFIG, **custom_model_kwargs)
        assert (
            self.cfg["num_layers"] == 1
        ), "num_layers != 1 has not been implemented yet"

        if self.cfg["preprocessor"] is self.DEFAULT_CONFIG["preprocessor"]:
            # The default preprocessor is a single module object stored on
            # the class. Using it directly would make every model built with
            # the defaults share the very same parameters (trained jointly,
            # moved between devices together), so each model gets its own
            # copy. A caller-supplied preprocessor is used as given.
            self.cfg["preprocessor"] = copy.deepcopy(self.cfg["preprocessor"])

        # Value estimates of the latest forward() call; see value_function().
        self.cur_val = None

        # Flattened observation -> preprocessor_input_size -> user module.
        self.preprocessor = torch.nn.Sequential(
            torch.nn.Linear(self.obs_dim, self.cfg["preprocessor_input_size"]),
            self.cfg["preprocessor"],
        )

        # NOTE(review): both heads assume the DNC emits `hidden_size`
        # features per step -- confirm for the configured `dnc_model`.
        self.logit_branch = SlimFC(
            in_size=self.cfg["hidden_size"],
            out_size=self.num_outputs,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )

        self.value_branch = SlimFC(
            in_size=self.cfg["hidden_size"],
            out_size=1,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )

        # Created by build_dnc() during the first forward() call, once the
        # device of the input is known.
        self.dnc: Union[None, DNC] = None

    def get_initial_state(self) -> List[TensorType]:
        """Return the zero-filled recurrent state (without a batch axis).

        The list has nine entries: two controller hidden tensors, the read
        vectors, and then one tensor per ``MEMORY_KEYS`` entry, in that order.
        """
        cfg = self.cfg
        ctrl_shape = (cfg["num_hidden_layers"], cfg["hidden_size"])
        n_cells = cfg["nr_cells"]
        n_heads = cfg["read_heads"]
        cell_width = cfg["cell_size"]

        # Shapes of the six memory tensors, listed in MEMORY_KEYS order.
        memory_shapes = (
            (n_cells, cell_width),  # memory
            (1, n_cells, n_cells),  # link_matrix
            (1, n_cells),  # precedence
            (n_heads, n_cells),  # read_weights
            (1, n_cells),  # write_weights
            (n_cells,),  # usage_vector
        )

        # Two controller hidden tensors of identical shape.
        state = [torch.zeros(*ctrl_shape), torch.zeros(*ctrl_shape)]
        # Read vectors: one cell-sized slice per read head, flattened.
        state.append(torch.zeros(cell_width * n_heads))
        state.extend(torch.zeros(*shape) for shape in memory_shapes)

        assert len(state) == 9
        return state

    def value_function(self) -> TensorType:
        """Return the value estimates cached by the latest ``forward()`` call.

        Raises:
            AssertionError: If ``forward()`` has not been called yet.
        """
        cached_values = self.cur_val
        assert cached_values is not None, "must call forward() first"
        return cached_values

    def unpack_state(
        self,
        state: List[TensorType],
    ) -> Tuple[List[Tuple[TensorType, TensorType]], Dict[str, TensorType], TensorType]:
        """Convert the flat state list into the structure passed to self.dnc.

        Args:
            state: Nine tensors ordered as two controller hidden tensors,
                the read vectors, then the memory tensors in
                ``MEMORY_KEYS`` order.

        Returns:
            A ``(ctrl_hidden, memory_dict, read_vecs)`` triple: a one-element
            list holding the pair of controller tensors (first two axes
            swapped and made contiguous), an ordered mapping from memory key
            to tensor, and the read vectors unchanged.

        Raises:
            AssertionError: If ``state`` does not hold exactly nine entries.
        """
        assert len(state) == 9, "Failed to verify unpacked state"

        # Swap the two leading axes of each controller tensor; presumably
        # (batch, layers, hidden) -> (layers, batch, hidden) -- TODO confirm.
        ctrl_pair = tuple(
            state[idx].permute(1, 0, 2).contiguous() for idx in (0, 1)
        )
        read_vecs = state[2]

        # Label the remaining tensors with their memory keys, by position.
        memory_dict = OrderedDict()
        for key, tensor in zip(self.MEMORY_KEYS, state[3:]):
            memory_dict[key] = tensor

        return [ctrl_pair], memory_dict, read_vecs

    def pack_state(
        self,
        ctrl_hidden: List[Tuple[TensorType, TensorType]],
        memory_dict: Dict[str, TensorType],
        read_vecs: TensorType,
    ) -> List[TensorType]:
        """Flatten the DNC hidden state into the nine-tensor state list.

        The result is ordered as: the two controller hidden tensors, the read
        vectors, then the memory tensors in ``MEMORY_KEYS`` order. This is
        the layout ``unpack_state()`` expects.

        Args:
            ctrl_hidden: Controller state; only the pair of tensors stored
                in its first element is used.
            memory_dict: Mapping holding one tensor per ``MEMORY_KEYS`` entry.
            read_vecs: Read vectors, stored unchanged.

        Returns:
            The flat list of nine tensors.

        Raises:
            AssertionError: If the keys of ``memory_dict`` are not exactly
                ``MEMORY_KEYS``.
        """
        assert set(memory_dict) == set(
            self.MEMORY_KEYS
        ), "Failed to verify packed state"

        # Swap the two leading axes back, undoing the permutation that
        # unpack_state() applies to the controller tensors.
        state = [
            ctrl_hidden[0][0].permute(1, 0, 2),
            ctrl_hidden[0][1].permute(1, 0, 2),
            read_vecs,
        ]
        # Look the tensors up by key rather than relying on the iteration
        # order of the dict handed back by the DNC: unpack_state() assigns
        # keys purely by position, so the order here must be MEMORY_KEYS.
        state.extend(memory_dict[key] for key in self.MEMORY_KEYS)
        assert len(state) == 9, "Failed to verify packed state"
        return state

    def validate_unpack(self, dnc_output, unpacked_state):
        """Check that an unpacked state has the same shapes as a DNC output.

        Args:
            dnc_output: ``(ctrl_hidden, memory_dict, read_vecs)`` as returned
                by the DNC; used as the reference.
            unpacked_state: Triple of the same structure, e.g. the result of
                ``unpack_state()``.

        Raises:
            AssertionError: On the first shape that differs from the
                reference.
        """
        s_ctrl_hidden, s_memory_dict, s_read_vecs = unpacked_state
        ctrl_hidden, memory_dict, read_vecs = dnc_output

        # Controller state: walk the reference and compare entry by entry.
        for i, layer_state in enumerate(ctrl_hidden):
            for j, expected in enumerate(layer_state):
                assert s_ctrl_hidden[i][j].shape == expected.shape, (
                    "Controller state mismatch: got "
                    f"{s_ctrl_hidden[i][j].shape} should be "
                    f"{expected.shape}"
                )

        # Memory state: every key of the reference must match in shape.
        for key, expected in memory_dict.items():
            assert s_memory_dict[key].shape == expected.shape, (
                "Memory state mismatch at key "
                f"{key}: got {s_memory_dict[key].shape} should be "
                f"{expected.shape}"
            )

        assert s_read_vecs.shape == read_vecs.shape, (
            "Read state mismatch: got "
            f"{s_read_vecs.shape} should be "
            f"{read_vecs.shape}"
        )

    def build_dnc(self, device_idx: Union[int, None]) -> None:
        """Instantiate the configured DNC and store it in ``self.dnc``.

        Args:
            device_idx: Forwarded to the DNC as ``gpu_id``.
        """
        cfg = self.cfg
        dnc_factory = cfg["dnc_model"]

        # The DNC consumes the preprocessor output.
        dnc_kwargs = {"input_size": cfg["preprocessor_output_size"]}

        # Settings whose config key is also the DNC keyword-argument name.
        same_name_keys = (
            "hidden_size",
            "num_layers",
            "num_hidden_layers",
            "read_heads",
            "cell_size",
            "nr_cells",
            "nonlinearity",
        )
        for key in same_name_keys:
            dnc_kwargs[key] = cfg[key]
        dnc_kwargs["gpu_id"] = device_idx

        self.dnc = dnc_factory(**dnc_kwargs)

    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        """Run preprocessor, DNC and both heads on a batch of sequences.

        Args:
            input_dict: Must contain ``"obs_flat"``, the flattened
                observations with batch and time folded into the first axis.
            state: Flat recurrent state in the layout produced by
                ``pack_state()`` / ``get_initial_state()``. It is ignored on
                the very first call, when the DNC is still being built.
            seq_lens: One entry per sequence; only its length is used, as
                the batch size.

        Returns:
            The logits, with batch and time folded into the first axis, and
            the new flat recurrent state. The value estimates are cached for
            ``value_function()``.
        """
        flat = input_dict["obs_flat"]
        # Batch and Time
        B = len(seq_lens)
        T = flat.shape[0] // B

        # First run
        if self.dnc is None:
            # The DNC is told which GPU to use; -1 is passed when the input
            # tensor has no device index.
            gpu_id = flat.device.index if flat.device.index is not None else -1
            self.build_dnc(gpu_id)
            # No usable state yet: hand the DNC an empty hidden triple.
            hidden = (None, None, None)

        else:
            hidden = self.unpack_state(state)  # type: ignore

        # Run thru preprocessor before DNC. The preprocessor works on
        # individual time steps, the DNC on [B, T, feats] sequences.
        z = self.preprocessor(flat.reshape(B * T, self.obs_dim))
        z = z.reshape(B, T, self.cfg["preprocessor_output_size"])
        output, hidden = self.dnc(z, hidden)
        packed_state = self.pack_state(*hidden)

        # Compute action/value from output. reshape() (unlike view()) also
        # accepts a non-contiguous DNC output; flatten once for both heads.
        features = output.reshape(B * T, -1)
        logits = self.logit_branch(features)
        values = self.value_branch(features)

        self.cur_val = values.squeeze(1)

        return logits, packed_state
