# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities associated with offloading functionality provided by `accelerate`.

| ------------------------------------------------------------------------------------------------------ | # noqa: E501
| Operation  | Without offloading support             | With offloading support                          | # noqa: E501
| ---------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501
| Add        | module.register_parameter(name, param) | register_offload_parameter(module, name, param)  | # noqa: E501
| Check      | N/A                                    | has_offloaded_params(module)                     | # noqa: E501
| Onload     | N/A                                    | with align_module_device(module)                 | # noqa: E501
| Update     | module.name.data.copy_(new_data)       | update_offload_parameter(module, name, new_data) | # noqa: E501
| Delete     | del module.name                        | delete_offload_parameter(module, name)           | # noqa: E501
| Add Module | module.register_module(name, child)    | register_offload_module(module, name, child)     | # noqa: E501
| Del Module | del module.name                        | delete_offload_module(module, name)              | # noqa: E501
| ------------------------------------------------------------------------------------------------------ | # noqa: E501
"""

import contextlib
import warnings
from functools import wraps
from operator import attrgetter
from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple, Union

import torch
from compressed_tensors.utils import patch_attr


try:
    from accelerate.hooks import (
        AlignDevicesHook,
        add_hook_to_module,
        attach_align_device_hook,
        named_module_tensors,
        remove_hook_from_module,
    )
    from accelerate.utils import (
        OffloadedWeightsLoader,
        PrefixedDataset,
        find_tied_parameters,
        set_module_tensor_to_device,
    )

    _has_accelerate = True

except ImportError:
    _has_accelerate = False
    AlignDevicesHook = None
    add_hook_to_module = None
    remove_hook_from_module = None
    OffloadedWeightsLoader = None
    PrefixedDataset = None
    set_module_tensor_to_device = None
    named_module_tensors = None
    attach_align_device_hook = None
    find_tied_parameters = None


# Public API of this module (the names exported by `from <module> import *`).
# Internal helpers such as `check_accelerate`, `offload_to_weights_map` and
# `delete_from_weights_map` are defined in this file but not listed here.
__all__ = [
    "get_execution_device",
    "get_offloaded_device",
    "update_parameter_data",
    "register_offload_parameter",
    "update_offload_parameter",
    "delete_offload_parameter",
    "has_offloaded_params",
    "disable_hf_hook",
    "disable_offload",
    "align_modules",
    "align_module_device",
    "register_offload_module",
    "delete_offload_module",
    "offloaded_dispatch",
    "disable_offloading",
    "remove_dispatch",
    "cast_to_device",
]


def check_accelerate(fallback: Any):
    """
    Decorator factory guarding functions that depend on `accelerate`.

    If `accelerate` was imported successfully, the decorated function is returned
    untouched. Otherwise it is swapped for a stub with the same metadata: when
    `fallback` is the string "error" the stub raises a ValueError, and for any
    other `fallback` the stub simply returns `fallback` (the same object on every
    call).

    :param fallback: "error", or the value the stub should return
    :return: decorator to apply to the guarded function
    """

    def decorator(func: Callable[[Any], Any]):
        if _has_accelerate:
            return func

        # decided once, at decoration time
        should_raise = bool(fallback == "error")

        @wraps(func)
        def fallback_fn(*args, **kwargs):
            if should_raise:
                raise ValueError(
                    "Please install `accelerate` in order to use this function"
                )
            return fallback

        return fallback_fn

    return decorator


""" Candidates for Depreciation """


def get_offloaded_device(module: torch.nn.Module) -> torch.device:
    """
    Get the device on which a module's parameters are stored while offloaded,
    i.e. the device they are returned to after a forward pass.

    For an offloaded module this is the device of the first tensor found in the
    hook's weights map. If the module is not offloaded, or its weights map holds no
    entries, newly added weights should live alongside the module itself, so the
    module's execution device is returned instead.

    :param module: module to check
    :return: device the module's parameters are offloaded to
    """
    if has_offloaded_params(module):
        weights_map = module._hf_hook.weights_map
        # keys yielded by a PrefixedDataset are looked up in its underlying
        # `dataset`; plain mappings without `.dataset` are indexed directly
        dataset = getattr(weights_map, "dataset", weights_map)
        first_key = next(iter(weights_map.keys()), None)
        if first_key is not None:
            return dataset[first_key].device

    # if the module is not offloaded (or nothing of it is offloaded), then any
    # added weights should be placed on the module's execution device
    return get_execution_device(module)


def update_parameter_data(
    module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str
):
    """
    Older entry point for `update_offload_parameter` that takes its arguments in a
    different order. It overwrites an existing parameter's data and, if the module
    is offloaded, the parameter's entry in the offload weights map.

    :param module: module that owns the parameter
    :param new_param_data: tensor holding the new values
    :param param_name: name of the parameter on `module`
    """
    update_offload_parameter(module=module, name=param_name, data=new_param_data)


""" Candidates for Upstreaming """


def cast_to_device(device_spec: Union[int, str, torch.device]) -> torch.device:
    """
    Convert an integer device index, a device string, or a torch.device into a
    torch.device object.

    :param device_spec: Device index (int), device string (e.g. "cpu", "cuda:1"),
        or torch.device object. Non-negative integers map to the CUDA device with
        that index; negative integers map to CPU.
    :return: torch.device corresponding to the given device specification.
    """
    if isinstance(device_spec, int):
        return torch.device(f"cuda:{device_spec}" if device_spec >= 0 else "cpu")
    if isinstance(device_spec, str):
        # strings are valid device specs too; normalize them so callers always
        # receive a torch.device, as the return annotation promises
        return torch.device(device_spec)
    return device_spec


def get_execution_device(module: torch.nn.Module) -> torch.device:
    """
    Determine the device that inputs must be moved to before `module` runs.

    Submodules are visited in `module.modules()` order (the module itself first),
    which is assumed to match execution order. The first visited module that is
    either offloaded (its hook's execution device is used) or directly owns a
    parameter (that parameter's device is used) decides the answer.

    :param module: module to inspect; may be offloaded
    :return: onload device of the module, or CPU (with a warning) if none is found
    """
    for candidate in module.modules():
        if has_offloaded_params(candidate):
            return cast_to_device(candidate._hf_hook.execution_device)

        # the first directly-owned parameter, if any, settles the question
        for own_param in candidate.parameters(recurse=False):
            return own_param.device

    warnings.warn(f"Unable to get execution device of {module}, falling back to CPU")
    return torch.device("cpu")


def register_offload_parameter(
    module: torch.nn.Module,
    name: str,
    parameter: torch.nn.Parameter,
    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
):
    """
    Register a parameter to the given module which may be offloaded

    If the module is offloaded (see `has_offloaded_params`), the new parameter is
    also added to the hook's bookkeeping: `original_devices`, the `weights_map`,
    and, when present, `tied_params_map`. If none of the module's parameters were
    on a non-meta device at call time, the newly registered parameter is then
    replaced with a meta tensor to match.

    :param module: maybe offloaded module
    :param name: name of newly registered parameter
    :param parameter: parameter being registered
    :param offload_device: device on which weight will be offloaded to. If None is
        provided, then infer device from existing entries in the module's weights map
    """
    # must be evaluated before registering the new parameter: whether any existing
    # parameter (submodules included, as `parameters()` recurses by default) is on a
    # real device. Used as a proxy for "the module is currently onloaded"
    has_onload = any(p.device != torch.device("meta") for p in module.parameters())
    module.register_parameter(name, parameter)

    # do everything AlignDevicesHook.init_hook does
    # https://github.com/huggingface/accelerate/blob/main/src/accelerate/hooks.py#L281
    if has_offloaded_params(module):
        hook: AlignDevicesHook = module._hf_hook
        assert hook.weights_map is not None

        # append to original_devices
        hook.original_devices[name] = parameter.device

        # append to weights map
        offload_to_weights_map(hook.weights_map, name, parameter.data, offload_device)

        # append to tied_params_map
        offloaded = hook.weights_map[name]
        if hook.tied_params_map is not None:
            hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)

        # perform offloading: if nothing was onloaded before registration, replace
        # the new parameter with a meta tensor so it matches the rest of the module
        if not has_onload:
            set_module_tensor_to_device(module, name, "meta")


def _is_same_memory(a: torch.Tensor, b: torch.Tensor) -> bool:
    """
    Return True if `a` and `b` view exactly the same strided memory with the same
    dtype, shape and strides, i.e. copying one into the other would be a no-op.
    """
    return (
        a.layout == b.layout == torch.strided
        and a.device == b.device
        and a.dtype == b.dtype
        and a.shape == b.shape
        and a.stride() == b.stride()
        and a.data_ptr() == b.data_ptr()
    )


def update_offload_parameter(
    module: torch.nn.Module,
    name: str,
    data: torch.Tensor,
    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
):
    """
    Update the data of an existing parameter and its offload dict. Supports both
    parameters of offloaded modules and non-offloaded modules

    :param module: module containing the parameter to update
    :param name: name of module parameter to update
    :param data: tensor to update parameter with
    :param offload_device: device on which weight will be offloaded to. If None is
        provided, then infer device from existing entries in the weights map
    """
    param: torch.nn.Parameter = getattr(module, name)
    if param.data.shape != data.shape:
        warnings.warn(
            f"Shape of parameter being updated {param.data.shape} does not match shape "
            f"of update data {data.shape}"
        )

    # copy data into onloaded parameter if applicable. `param.data` returns a new
    # tensor object on every access, so an identity check against it never detects
    # that `data` already *is* the parameter's storage (e.g. when callers pass
    # `param.data` back in). Compare the underlying memory instead to skip the
    # redundant self-copy
    if param.device != torch.device("meta") and not _is_same_memory(param, data):
        param.data.copy_(data)

    # update offload dict
    if has_offloaded_params(module):
        weights_map = module._hf_hook.weights_map
        offload_to_weights_map(weights_map, name, data, offload_device)


def delete_offload_parameter(module: torch.nn.Module, name: str):
    """
    Remove a parameter from a module. When the module is offloaded, the
    parameter's entry in the hook's weights map is removed as well.

    :param module: module which may be offloaded
    :param name: name of the parameter to remove
    """
    delattr(module, name)

    if not has_offloaded_params(module):
        return

    delete_from_weights_map(module._hf_hook.weights_map, name)


@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def disable_hf_hook(module: torch.nn.Module):
    """
    Context manager which temporarily removes the accelerate hook (`_hf_hook`) from
    `module` and every submodule that has one, re-attaching the same hook objects on
    exit. Hooks are re-attached even if the body of the context raises. Without
    `accelerate` installed, this is a no-op context.

    :param module: module whose hooks (and those of its submodules) are disabled
    """
    hooks = {}

    def collect_hooks(submodule):
        if hasattr(submodule, "_hf_hook"):
            hooks[submodule] = submodule._hf_hook
            remove_hook_from_module(submodule)

    module.apply(collect_hooks)

    try:
        yield
    finally:
        # restore hooks even on error, otherwise the model silently loses them
        for submodule, hook in hooks.items():
            add_hook_to_module(submodule, hook)


@check_accelerate(fallback=None)
def offload_to_weights_map(
    weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
    key: str,
    value: torch.Tensor,
    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
):
    """
    Helper function which implements offloaded item assignment for PrefixedDataset,
    OffloadedWeightsLoader, and Dict types.

    :param weights_map: weight map to be updated with offload information
    :param key: key used to identify weight location
    :param value: weight being offloaded
    :param offload_device: device on which weight will be offloaded to. If None is
        provided, then infer device from parameters in weights_map
    :raises ValueError: if "disk" is requested for a PrefixedDataset or dict, or if
        the offload device cannot be inferred from an empty dict
    :raises NotImplementedError: if an OffloadedWeightsLoader has a non-empty
        `index` or "disk" is requested for it, or if `weights_map` has an
        unsupported type
    """
    if isinstance(weights_map, PrefixedDataset):
        if offload_device == "disk":
            raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")

        dataset = weights_map.dataset
        key = f"{weights_map.prefix}{key}"
        offload_to_weights_map(dataset, key, value, offload_device)

    elif isinstance(weights_map, OffloadedWeightsLoader):
        if len(weights_map.index) > 0 or offload_device == "disk":
            raise NotImplementedError(
                "Updating weights_map with disk offloading is not implemented yet"
            )

        # store the value first and only then register the key, so a failed update
        # (e.g. an offload device that cannot be inferred) does not leave a key in
        # `all_keys` without a corresponding entry in `state_dict`
        offload_to_weights_map(weights_map.state_dict, key, value, offload_device)
        if key not in weights_map.all_keys:
            weights_map.all_keys.append(key)

    elif isinstance(weights_map, dict):
        if offload_device == "disk":
            raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")

        # infer offload device
        if offload_device is None:
            if key in weights_map:
                offload_device = weights_map[key].device
            else:
                tens = next(iter(weights_map.values()), None)
                if tens is None:
                    raise ValueError(
                        "Cannot infer offload device from empty weights_map"
                    )
                offload_device = tens.device

        weights_map[key] = value.to(device=offload_device)

    else:
        raise NotImplementedError(
            "Updating offload data not implemented for weights_map of type "
            f"{type(weights_map)}"
        )


@check_accelerate(fallback=None)
def delete_from_weights_map(
    weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
    key: str,
):
    """
    Helper function which implements offloaded item deletion for PrefixedDataset,
    OffloadedWeightsLoader, and Dict types.

    :param weights_map: weight map to remove the entry from
    :param key: key identifying the weight (relative to any PrefixedDataset prefix)
    :raises KeyError: if the (prefixed) key is not present in the underlying dict
    :raises NotImplementedError: if an OffloadedWeightsLoader has a non-empty
        `index`, or if `weights_map` has an unsupported type
    """
    if isinstance(weights_map, PrefixedDataset):
        dataset = weights_map.dataset
        key = f"{weights_map.prefix}{key}"
        delete_from_weights_map(dataset, key)

    elif isinstance(weights_map, OffloadedWeightsLoader):
        if len(weights_map.index) > 0:
            raise NotImplementedError(
                "Delete from weights_map with disk offloading is not implemented yet"
            )

        delete_from_weights_map(weights_map.state_dict, key)
        # `offload_to_weights_map` registers keys in `all_keys` alongside
        # `state_dict`; remove them together so the two stay consistent
        if key in weights_map.all_keys:
            weights_map.all_keys.remove(key)

    elif isinstance(weights_map, dict):
        del weights_map[key]

    else:
        raise NotImplementedError(
            "Deleting offload data not implemented for weights_map of type "
            f"{type(weights_map)}"
        )


@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def disable_offload(module: torch.nn.Module):
    """
    Context manager to disable module onloading and offloading. Parameters will stay on
    their current device

    Offloading is re-enabled on exit even if the body of the context raises. For a
    module without offloaded parameters (or without `accelerate`), this is a no-op.

    :param module: module to disable offloading for
    """
    if not has_offloaded_params(module):
        yield
        return

    hook = module._hf_hook
    hook.offload = False
    try:
        yield
    finally:
        # `has_offloaded_params` guarantees offloading was enabled on entry; restore
        # it even on error so later post_forward calls offload the weights again
        hook.offload = True


@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def align_modules(
    modules: Union[torch.nn.Module, Iterable[torch.nn.Module]],
    execution_device: Optional[torch.device] = None,
):
    """
    Onload one or more modules for the duration of the context, while suppressing
    the onload/offload steps that their forward calls would otherwise trigger.
    Intended for sequential onloading of layers.

    :param modules: a single `torch.nn.Module` or an iterable of modules to onload
    :param execution_device: device to onload to; if None, offloaded modules use
        their hook's execution device
    """
    if isinstance(modules, torch.nn.Module):
        modules = (modules,)

    with contextlib.ExitStack() as stack:
        for mod in modules:
            aligned = align_module_device(mod, execution_device)
            stack.enter_context(aligned)
            # weights were just onloaded, so onloading again on forward is redundant
            stack.enter_context(disable_offload(mod))
        yield


def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module):
    """
    Register a submodule with offloading if the parent module is offloaded

    When `base` is offloaded, each direct parameter and buffer of `module` is moved
    (via `.to`) to the parent's offload device and added to the parent's weights map
    under the key `"{name}.{tensor_name}"`. If the parent hook does not place
    submodules, a new `AlignDevicesHook` is attached to `module`. That hook reuses
    the parent's execution device and tied params map. In all cases, `module` is
    then registered on `base` under `name`.

    :param base: module to attach submodule to
    :param name: name of submodule
    :param module: submodule to attach
    """

    if has_offloaded_params(base):
        hook: AlignDevicesHook = base._hf_hook
        # sanity checks; `has_offloaded_params` already implies `hook.offload`
        assert hook.offload
        assert hook.weights_map is not None

        # offloading kwargs for submodule: only the submodule's own direct tensors
        # (not those of modules nested inside it) are handled, buffers included
        place_submodules = False
        offload_buffers = True

        # copy device offloading arguments from parent
        current_device = next(base.parameters()).device  # assume base has parameters
        offload_device = get_offloaded_device(base)

        # offload parameters to weights map
        for param_name, param in named_module_tensors(
            module, include_buffers=offload_buffers, recurse=place_submodules
        ):
            offloaded = param.to(offload_device)
            if hook.tied_params_map is not None:
                hook.tied_params_map[offloaded.data_ptr()] = {}  # (1)
            offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded)

            # if the parent places submodules, no hook is attached to `module`
            # below (presumably the parent hook manages these tensors), so move
            # each tensor to the parent's current device now (e.g. meta while
            # offloaded)
            if hook.place_submodules:
                set_module_tensor_to_device(module, param_name, current_device)

        # if the parent does not place submodules, then add a hook
        # parameters are offloaded by `add_hook_to_module`
        # NOTE(review): assumes the parent's weights_map is a PrefixedDataset, since
        # its `.dataset` and `.prefix` attributes are used
        if not hook.place_submodules:
            weights_map = PrefixedDataset(
                hook.weights_map.dataset, prefix=f"{hook.weights_map.prefix}{name}."
            )

            submodule_hook = AlignDevicesHook(
                execution_device=hook.execution_device,
                offload=hook.offload,
                io_same_device=False,
                weights_map=weights_map,
                offload_buffers=offload_buffers,
                place_submodules=place_submodules,
                skip_keys=None,
                tied_params_map=hook.tied_params_map,
            )
            add_hook_to_module(module, submodule_hook)

    base.register_module(name, module)


def delete_offload_module(base: torch.nn.Module, name: str):
    """
    Delete a submodule from a model which may contain offloading

    Parameters owned by the submodule and by every module nested inside it are
    removed one owning module at a time. This way each parameter is also removed
    from the weights map of the module that owns it, if that module is offloaded.
    The submodule itself is then detached from `base`.

    :param base: parent module to delete submodule from
    :param name: name of submodule on parent
    """
    module: torch.nn.Module = getattr(base, name)

    # `named_parameters()` recurses by default and yields dotted names such as
    # "child.weight", which are not attributes of `module` and would make `delattr`
    # fail; instead delete each module's direct parameters from that module itself
    for submodule in list(module.modules()):
        for param_name, _ in list(submodule.named_parameters(recurse=False)):
            delete_offload_parameter(submodule, param_name)

    delattr(base, name)


@check_accelerate(fallback="error")
def offloaded_dispatch(
    module: torch.nn.Module,
    execution_device: torch.device,
    offload_device: Union[torch.device, Literal["disk"]] = torch.device("cpu"),
) -> torch.nn.Module:
    """
    Unlike `dispatch_model`, this function forces a module (and its submodules) to
    offload all parameters and replace them with meta tensors, utiliizing the
    `AlignDevicesHook` to control onloading and offloading.

    :param module: module containing parameters to offload
    :param execution_device: device that modules will onload and execute on
    :param offload_device: device that module parameters will offload to
    :return: module with offloading device hooks
    """
    if offload_device == "disk":
        raise NotImplementedError("Disk offloading is not currently supported")

    # remove any existing hooks
    remove_dispatch(module)

    # create weights map
    state_dict = module.state_dict()
    state_dict = {key: val.to(offload_device) for key, val in state_dict.items()}
    weights_map = OffloadedWeightsLoader(state_dict=state_dict, device=offload_device)

    # create tied params map
    tied_params = find_tied_parameters(module)
    tied_params_map = {}
    for group in tied_params:
        for param_name in group:
            data_ptr = attrgetter(param_name)(module).data_ptr()
            tied_params_map[data_ptr] = {}

    # recursively attaches hooks to all submodules
    attach_align_device_hook(
        module,
        execution_device=execution_device,
        offload=True,
        weights_map=weights_map,
        tied_params_map=tied_params_map,
    )

    # when saving a model, `PretrainedModel.save_pretrained` will only
    # onload weights if the following requirements are met
    # if (
    #     hasattr(self, "hf_device_map")
    #     and len(set(self.hf_device_map.values())) > 1
    #     and ("cpu" in self.hf_device_map.values()
    #          or "disk" in self.hf_device_map.values())
    # ):
    # because this function always offloads, disregard actual devices and
    # always use `cpu` and `cuda:0` to guarantee this condition passes
    setattr(module, "hf_device_map", {"fake_offload": "cpu", "fake_exec": "cuda:0"})

    return module


def remove_dispatch(module: torch.nn.Module) -> torch.nn.Module:
    """
    Remove any existing dispatch from a module. Accelerate hooks are stripped from
    the module and all of its submodules, its `hf_device_map` attribute is dropped,
    and the module is moved to CPU.

    :param module: module which may be dispatched with hf hooks
    :return: the same module, without dispatch
    """
    # without accelerate, `remove_hook_from_module` is None and no accelerate hooks
    # can exist, so there is nothing to strip
    if _has_accelerate:
        remove_hook_from_module(module, recurse=True)
    if hasattr(module, "hf_device_map"):
        delattr(module, "hf_device_map")
    module.to("cpu")

    return module


@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def disable_offloading():
    """
    Keep modules onloaded and disable offloading until this context exits.
    Affects modules which have been hooked with accelerate's `AlignDevicesHook`

    Each hooked module onloads on its first forward pass inside the context and then
    stays onloaded. On exit, including when the body raises, each such module is
    restored: its original `offload` setting is put back, its parameters are written
    back to the offload weights map (they may have changed), and it is offloaded
    again. Without `accelerate` installed, this is a no-op context.
    """
    original_pre_forward = AlignDevicesHook.pre_forward
    onloaded_modules: Dict[torch.nn.Module, Tuple[AlignDevicesHook, bool]] = dict()

    # onload once and disable any future onloading/offloading steps
    def keep_onload_pre_forward(self: AlignDevicesHook, module, *args, **kwargs):
        ret = original_pre_forward(self, module, *args, **kwargs)
        if module not in onloaded_modules:
            onloaded_modules[module] = (self, self.offload)
            self.offload = False
        return ret

    try:
        # use the patched pre_forward function within the context
        with patch_attr(AlignDevicesHook, "pre_forward", keep_onload_pre_forward):
            yield
    finally:
        # manually offload all modules that were onloaded, even on error, so they
        # are not left onloaded with offloading permanently disabled.
        # update any parameters which may have changed
        for module, (hook, offload) in onloaded_modules.items():
            hook.offload = offload
            for name, param in module.named_parameters(recurse=False):
                update_offload_parameter(module, name, param.data)
            hook.post_forward(module, None)


""" Upstreamed Functions """


# introduced in accelerate v1.1.0
# NOTE: local copy of the upstream helper, presumably kept so this module also works
# with accelerate versions older than v1.1.0; keep it identical to upstream
@check_accelerate(fallback=False)
def has_offloaded_params(module: torch.nn.Module) -> bool:
    """
    Checks if a module has offloaded parameters by checking if the given module has an
    AlignDevicesHook attached with offloading enabled

    When `accelerate` is not installed, this always returns `False`.

    Args:
        module (`torch.nn.Module`): The module to check for an offload hook.

    Returns:
        bool: `True` if the module has an offload hook and offloading is enabled,
        `False` otherwise.
    """
    return (
        hasattr(module, "_hf_hook")
        and isinstance(module._hf_hook, AlignDevicesHook)
        and module._hf_hook.offload
    )


# introduced in accelerate v1.1.0
@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def align_module_device(
    module: torch.nn.Module, execution_device: Optional[torch.device] = None
):
    """
    Context manager that moves a module's parameters to the specified execution device.

    Args:
        module (`torch.nn.Module`):
            Module with parameters to align.
        execution_device (`torch.device`, *optional*):
            If provided, overrides the module's execution device within the context.
            Otherwise, use hook execution device or pass

    For an offloaded module, the hook's `pre_forward` onloads the weights on entry.
    Its `post_forward` offloads them again on exit, also when the body raises. For a
    non-offloaded module with an explicit `execution_device`, only the module's
    direct parameters (`recurse=False`) are moved, and each is moved back to its
    previous device on exit. Buffers and submodule parameters are not moved.
    Without `accelerate` installed, this is a no-op context.
    """
    if has_offloaded_params(module):
        if execution_device is not None:
            # temporarily override the hook's execution device; restored in `finally`
            original_device = module._hf_hook.execution_device
            module._hf_hook.execution_device = execution_device

        try:
            # onload weights onto the (possibly overridden) execution device
            module._hf_hook.pre_forward(module)
            yield
        finally:
            # offload weights again, even if the body raised
            module._hf_hook.post_forward(module, None)
            if execution_device is not None:
                module._hf_hook.execution_device = original_device

    elif execution_device is not None:
        # remember each direct parameter's current device so it can be restored
        devices = {
            name: param.device for name, param in module.named_parameters(recurse=False)
        }
        try:
            for name in devices:
                set_module_tensor_to_device(module, name, execution_device)
            yield
        finally:
            for name, device in devices.items():
                set_module_tensor_to_device(module, name, device)

    else:
        # not offloaded and no target device: nothing to move
        yield


# (1): Since we cannot know which pointers are shared when we add parameters in an
# online way, assume that all pointers are shared. This has virtually no runtime cost
