# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import warnings
from functools import wraps
from types import MappingProxyType
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    TypeVar,
)

import numpy
import torch
from transformers import AutoConfig, PretrainedConfig


T = TypeVar("T", bound="Callable")  # used by `deprecated`


if TYPE_CHECKING:
    from compressed_tensors.compressors import ModelCompressor


__all__ = [
    "infer_compressor_from_model_config",
    "fix_fsdp_module_name",
    "tensor_follows_mask_structure",
    "replace_module",
    "is_compressed_tensors_config",
    "getattr_chain",
    "deprecated",
    "Aliasable",
    "combine_shards",
    "shard_tensor",
    "pack_bitmasks",
    "unpack_bitmasks",
    "patch_attr",
    "patch_attrs",
    "ParameterizedDefaultDict",
    "get_num_attn_heads",
    "get_num_kv_heads",
    "get_head_dim",
]

FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"


def infer_compressor_from_model_config(
    pretrained_model_name_or_path: str,
) -> Optional["ModelCompressor"]:  # noqa: F821
    """
    Given a path to a model config, extract a sparsity config if it exists and return
    the associated ModelCompressor

    :param pretrained_model_name_or_path: path to model config on disk or HF hub
    :return: matching compressor if config contains a sparsity config
    """
    from compressed_tensors.compressors import ModelCompressor
    from compressed_tensors.config import CompressionConfig

    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
    sparsity_config = ModelCompressor.parse_sparsity_config(config)
    if sparsity_config is None:
        return None

    format = sparsity_config.get("format")
    sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config)
    compressor = ModelCompressor.load_from_registry(format, config=sparsity_config)
    return compressor


def fix_fsdp_module_name(name: str) -> str:
    """
    Remove FSDP wrapper prefixes from a module name.
    Accounts for the scenario where FSDP_WRAPPER_NAME appears
    at the end of the name, as well as in the middle.
    :param name: name to strip
    :return: stripped name
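
    Example:
    >>> fix_fsdp_module_name("model._fsdp_wrapped_module.layers.0")
    'model.layers.0'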
    """
    return name.replace(FSDP_WRAPPER_NAME + ".", "").replace(
        "." + FSDP_WRAPPER_NAME, ""
    )


def tensor_follows_mask_structure(tensor: torch.Tensor, mask: str = "2:4") -> bool:
    """
    :param tensor: tensor to check
    :param mask: mask structure to check for, in the format "n:m"
    :return: True if the tensor follows the mask structure
    :raises ValueError: if any chunk of size m contains fewer than n zeros.
        Note that some weights can incidentally be zero, so we only require
        at least n zeros in each chunk of size m
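
    Example:
    >>> t = torch.tensor([0.0, 1.0, 0.0, 2.0, 3.0, 0.0, 0.0, 4.0])
    >>> tensor_follows_mask_structure(t, mask="2:4")
    True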
    """

    n, m = tuple(map(int, mask.split(":")))
    # Reshape the tensor into chunks of size m
    tensor = tensor.view(-1, m)

    # Count the number of zeros in each chunk
    zero_counts = (tensor == 0).sum(dim=1)

    # Check that each chunk contains at least n zeros
    # A >= comparison is used because some weights can incidentally be zero
    if not torch.all(zero_counts >= n).item():
        raise ValueError(f"Tensor does not follow the {mask} mask structure")

    return True


def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
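    """
    Replace a submodule of a model with a new module, looked up by its
    dot-separated name (e.g. "layers.0.mlp")

    :param model: model containing the submodule to replace
    :param name: dot-separated name of the submodule to replace
    :param new_module: module to set in place of the existing submodule
    """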
    if "." in name:
        parent_name = name.rsplit(".", 1)[0]
        child_name = name[len(parent_name) + 1 :]
        parent = model.get_submodule(parent_name)
    else:
        parent_name = ""
        parent = model
        child_name = name
    setattr(parent, child_name, new_module)


def is_compressed_tensors_config(compression_config: Any) -> bool:
    """
    Returns True if CompressedTensorsConfig is available from transformers and
    compression_config is an instance of CompressedTensorsConfig

    See: https://github.com/huggingface/transformers/pull/31704
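
    Example (a plain dict is not a CompressedTensorsConfig):
    >>> is_compressed_tensors_config({"format": "dense"})
    False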
    """
    try:
        from transformers.utils.quantization_config import CompressedTensorsConfig

        return isinstance(compression_config, CompressedTensorsConfig)
    except ImportError:
        return False


def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
    """
    Chain multiple getattr calls, separated by `.`

    :param obj: base object whose attributes are being retrieved
    :param chain_str: attribute names separated by `.`
    :param default: optional default value to return when an attribute in the
        chain is missing; if no default is given, an AttributeError is raised
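
    Example:
    >>> from types import SimpleNamespace
    >>> obj = SimpleNamespace(child=SimpleNamespace(value=1))
    >>> getattr_chain(obj, "child.value")
    1
    >>> getattr_chain(obj, "child.missing", "default")
    'default'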
    """
    if len(args) >= 1:
        has_default = True
        default = args[0]
    elif "default" in kwargs:
        has_default = True
        default = kwargs["default"]
    else:
        has_default = False

    attr_names = chain_str.split(".")

    res = obj
    for attr_name in attr_names:
        if not hasattr(res, attr_name):
            if has_default:
                return default
            else:
                raise AttributeError(f"{res} object has no attribute {attr_name}")
        res = getattr(res, attr_name)

    return res


def deprecated(
    future_name: Optional[str] = None, message: Optional[str] = None
) -> Callable[[T], T]:
    """
    Decorator to mark functions as deprecated

    :param future_name: name of the function to use in place of the
        deprecated function
    :param message: deprecation message, replaces the default deprecation message
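
    Usage:
    >>> @deprecated(future_name="new_fn")
    ... def old_fn():
    ...     return 1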
    """

    def decorator(func: T) -> T:
        nonlocal message

        if message is None:
            message = (
                f"{func.__name__} is deprecated and will be removed in a future release"
            )
            if future_name is not None:
                message += f". Please use {future_name} instead."

        @wraps(func)
        def wrapped(*args, **kwargs):
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapped

    return decorator


class Aliasable:
    """
    A mixin for enums to allow aliasing of enum members

    Example:
    >>> class MyClass(Aliasable, int, Enum):
    ...     ...
    """

    @staticmethod
    def get_aliases() -> Dict[str, str]:
        raise NotImplementedError()

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            aliases = self.get_aliases()
            return self.value == other.value or (
                aliases.get(self.value, self.value)
                == aliases.get(other.value, other.value)
            )
        else:
            aliases = self.get_aliases()
            self_value = aliases.get(self.value, self.value)
            other_value = aliases.get(other, other)
            return self_value == other_value

    def __hash__(self):
        canonical_value = self.get_aliases().get(self.value, self.value)
        return hash(canonical_value)


def shard_tensor(
    tensor: torch.Tensor, shard_sizes: List[int], dim: int = 0
) -> List[torch.Tensor]:
    """
    Shards a tensor into a list of tensors along a given dimension.

    :param tensor: The input tensor to shard.
    :param shard_sizes: List of sizes for each shard along the specified dimension.
    :param dim: The dimension along which to shard the tensor.
    :raises ValueError: If the sum of shard_sizes does not match the
        size of the tensor along the given dimension.
    :returns: A list of tensors sharded along the specified dimension.
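
    Example:
    >>> tensor = torch.arange(6)
    >>> [shard.tolist() for shard in shard_tensor(tensor, [2, 4])]
    [[0, 1], [2, 3, 4, 5]]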
    """
    if sum(shard_sizes) != tensor.size(dim):
        raise ValueError(
            "Sum of shard_sizes must equal the size of the tensor "
            "along the specified dimension."
        )

    shards = []
    start_idx = 0

    for size in shard_sizes:
        end_idx = start_idx + size
        shard = tensor.narrow(dim, start_idx, size)
        shards.append(shard)
        start_idx = end_idx

    return shards


def combine_shards(shards: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
    """
    Combine decompressed shards along a given dimension using `narrow`.

    :param shards: List of decompressed shard tensors.
    :param dim: Dimension to combine along (default: 0).
    :return: Combined decompressed tensor.
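
    Example:
    >>> shards = [torch.ones(2, 3), torch.zeros(1, 3)]
    >>> combine_shards(shards).shape
    torch.Size([3, 3])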
    """
    if not shards:
        raise ValueError("The list of shards is empty.")

    # Check that all shards have the same dtype
    shard_dtypes = {shard.dtype for shard in shards}
    if len(shard_dtypes) > 1:
        raise ValueError("All shards must have the same dtype.")

    # Determine the total shape of the combined tensor
    total_shape = list(shards[0].shape)
    total_shape[dim] = sum(shard.shape[dim] for shard in shards)

    # Create the combined tensor
    combined = torch.zeros(total_shape, dtype=shards[0].dtype, device=shards[0].device)

    # Fill the combined tensor using narrow
    shard_offset = 0
    for shard in shards:
        shard_size = shard.shape[dim]
        combined.narrow(dim, shard_offset, shard_size).copy_(shard)
        shard_offset += shard_size

    return combined


def pack_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor:
    """
    Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
    compressed to R x ceil(C/8)

    :param bytemasks: mask tensor where each byte corresponds to a weight
    :return: mask tensor where each bit corresponds to a weight
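
    Example:
    >>> mask = torch.tensor([[True] * 8 + [False] * 8])
    >>> pack_bitmasks(mask).tolist()
    [[255, 0]]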
    """
    packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
    packed_bits_torch = torch.from_numpy(packed_bits_numpy)

    return packed_bits_torch


def unpack_bitmasks(
    packed_bitmasks: torch.Tensor, original_shape: List[int]
) -> torch.Tensor:
    """
    Converts a bitmask tensor back to a bytemask tensor for use during decompression

    :param packed_bitmasks: mask tensor where each bit corresponds to a weight
    :param original_shape: dense shape to decompress to
    :return: boolean mask of weights in the original dense shape
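
    Example:
    >>> mask = torch.tensor([[True, False, True, False]])
    >>> unpack_bitmasks(pack_bitmasks(mask), [1, 4]).tolist()
    [[True, False, True, False]]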
    """
    # Unpack the bits
    unpacked_bits = numpy.unpackbits(
        packed_bitmasks.cpu().numpy(),
        axis=-1,
        count=original_shape[-1],
        bitorder="little",
    )

    # Reshape to match the original shape
    unpacked_bitmasks_torch = torch.from_numpy(
        unpacked_bits.reshape(original_shape).astype(bool)
    )

    return unpacked_bitmasks_torch


@contextlib.contextmanager
def patch_attr(base: object, attr: str, value: Any):
    """
    Patch the value of an object attribute. Original value is restored upon exit

    :param base: object which has the attribute to patch
    :param attr: name of the attribute to patch
    :param value: used to replace original value

    Usage:
    >>> from types import SimpleNamespace
    >>> obj = SimpleNamespace()
    >>> with patch_attr(obj, "attribute", "value"):
    ...     assert obj.attribute == "value"
    >>> assert not hasattr(obj, "attribute")
    """
    _sentinel = object()
    original_value = getattr(base, attr, _sentinel)

    setattr(base, attr, value)
    try:
        yield
    finally:
        if original_value is not _sentinel:
            setattr(base, attr, original_value)
        else:
            delattr(base, attr)


@contextlib.contextmanager
def patch_attrs(bases: Iterable[Any], attr: str, values: Iterable[Any]):
    """
    Same as `patch_attr`, but patches an attribute on a list of objects with a
    list of values. Original values are restored upon exit

    :param bases: objects which have the attribute to patch
    :param attr: name of the attribute to patch
    :param values: values used to replace the original values. Must be the same
        length as bases

    Usage:
    >>> from types import SimpleNamespace
    >>> obj1 = SimpleNamespace()
    >>> obj2 = SimpleNamespace()
    >>> with patch_attrs([obj1, obj2], "attribute", ["value1", "value2"]):
    ...     assert obj1.attribute == "value1"
    ...     assert obj2.attribute == "value2"
    >>> assert not hasattr(obj1, "attribute")
    >>> assert not hasattr(obj2, "attribute")
    """
    with contextlib.ExitStack() as stack:
        for base, value in zip(bases, values):
            stack.enter_context(patch_attr(base, attr, value))
        yield


class ParameterizedDefaultDict(dict):
    """
    Similar to `collections.defaultdict`, but upon fetching a key which is missing,
    the key is passed as arguments to the `default_factory`

    :param default_factory: function which takes a key as input and returns the
        corresponding default value
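
    Example:
    >>> squares = ParameterizedDefaultDict(lambda x: x * x)
    >>> squares[3]
    9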
    """

    def __init__(self, default_factory: Callable[[Any], Any]):
        self.default_factory = default_factory
        self._factory_kwargs = MappingProxyType({})

    def __missing__(self, key: Any) -> Any:
        if isinstance(key, tuple):
            value = self.default_factory(*key, **self._factory_kwargs)
        else:
            value = self.default_factory(key, **self._factory_kwargs)
        self[key] = value
        return value

    def get(self, *args, factory_kwargs: Mapping = MappingProxyType({})) -> Any:
        """
        Similar to `__getitem__`, but allows passing kwargs to factory function

        :param \\*args: args whose tuple will be treated as the key
        :param factory_kwargs: keyword arguments to pass to `default_factory`
        :return: dictionary entry for given key
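
        Example:
        >>> d = ParameterizedDefaultDict(lambda x, y, scale=1: (x + y) * scale)
        >>> d.get(1, 2, factory_kwargs={"scale": 10})
        30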
        """
        with patch_attr(self, "_factory_kwargs", factory_kwargs):
            return self[args]


def get_num_attn_heads(config: PretrainedConfig) -> int:
    """
    Get the number of attention heads used by a model

    :param config: model config
    :return: num_attention_heads of model
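
    Example (illustrative config values):
    >>> get_num_attn_heads(PretrainedConfig(num_attention_heads=32))
    32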
    """
    if hasattr(config, "num_attention_heads"):
        return config.num_attention_heads

    elif hasattr(config, "hidden_size") and hasattr(config, "head_dim"):
        return config.hidden_size // config.head_dim

    else:
        raise ValueError(
            "Cannot determine num_attention_heads from config. Config must define "
            "either `num_attention_heads` or both `hidden_size` and `head_dim`. "
            f"{config}"
        )


def get_num_kv_heads(config: PretrainedConfig) -> int:
    """
    Get the number of key-value attention heads used by a model

    :param config: model config
    :return: num_key_value_heads of model
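
    Example (illustrative config values):
    >>> get_num_kv_heads(PretrainedConfig(num_key_value_heads=8))
    8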
    """
    if hasattr(config, "num_key_value_heads"):
        return config.num_key_value_heads

    else:
        raise ValueError(
            "Cannot determine num_key_value_heads from config. Config must define "
            f"`num_key_value_heads`. {config}"
        )


def get_head_dim(config: PretrainedConfig) -> int:
    """
    Get the number of dimensions used by the attention heads of a model

    :param config: model config
    :return: head_dim of model
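
    Example (illustrative config values; head_dim inferred from hidden_size):
    >>> get_head_dim(PretrainedConfig(hidden_size=4096, num_attention_heads=32))
    128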
    """
    if hasattr(config, "head_dim"):
        return config.head_dim

    elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
        return config.hidden_size // config.num_attention_heads

    else:
        raise ValueError(
            "Cannot determine head_dim from config. Config must define "
            "either `head_dim` or both `hidden_size` and `num_attention_heads`. "
            f"{config}"
        )
