# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional, Set, Union

from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs
from compressed_tensors.quantization.quant_scheme import (
    QuantizationScheme,
    preset_name_to_scheme,
)
from compressed_tensors.quantization.utils import is_module_quantized, module_type
from pydantic import BaseModel, ConfigDict, Field
from torch.nn import Module


# public API re-exported by ``from ... import *``
__all__ = [
    # enum / model classes
    "QuantizationStatus",
    "QuantizationConfig",
    # module-level constants
    "LIFECYCLE_ORDER",
    "DEFAULT_QUANTIZATION_METHOD",
    "DEFAULT_QUANTIZATION_FORMAT",
]


class QuantizationStatus(str, Enum):
    """
    Enum storing the different states a quantized layer can be in

    Initialized: scale, zero points and observers have been attached to the layer but
    are set to dummy values (not yet calibrated)
    Calibration: scale and zero points have been calibrated through OBCQ or similar
    algorithm, observers are still attached
    Frozen: scale and zero points are finalized, observers have been deleted, weights
    are still in their original precision
    Compressed: weights have been converted to their target type or compressed to
    their closest approximation

    Statuses are ordered by their position in the lifecycle and support ``<``,
    ``<=``, ``>`` and ``>=`` against other statuses. ``None`` is treated as
    preceding every status. Comparing against any other type raises
    NotImplementedError.
    """

    INITIALIZED = "initialized"
    CALIBRATION = "calibration"
    FROZEN = "frozen"
    COMPRESSED = "compressed"

    @classmethod
    def lifecycle_order(cls) -> List["QuantizationStatus"]:
        """
        :return: list of correct quantization lifecycle order
        """
        # members are declared above in lifecycle order and Enum iteration follows
        # declaration order; a fresh list is returned so callers cannot mutate it
        return list(cls)

    def _lifecycle_indices(self, other):
        """
        :param other: status to compare against
        :return: (position of self, position of other) within LIFECYCLE_ORDER
        :raises NotImplementedError: if other is not a QuantizationStatus
        """
        if not isinstance(other, self.__class__):
            raise NotImplementedError
        return LIFECYCLE_ORDER.index(self), LIFECYCLE_ORDER.index(other)

    def __ge__(self, other):
        if other is None:
            return True
        own, theirs = self._lifecycle_indices(other)
        return own >= theirs

    def __gt__(self, other):
        if other is None:
            return True
        own, theirs = self._lifecycle_indices(other)
        return own > theirs

    def __lt__(self, other):
        if other is None:
            return False
        own, theirs = self._lifecycle_indices(other)
        return own < theirs

    def __le__(self, other):
        if other is None:
            return False
        own, theirs = self._lifecycle_indices(other)
        return own <= theirs


# Quantization statuses from earliest to latest lifecycle stage.
# QuantizationStatus comparison operators rank statuses by their index here.
LIFECYCLE_ORDER = [
    QuantizationStatus.INITIALIZED,  # observers attached, dummy qparams
    QuantizationStatus.CALIBRATION,  # qparams calibrated, observers attached
    QuantizationStatus.FROZEN,  # qparams final, observers removed
    QuantizationStatus.COMPRESSED,  # weights converted / compressed
]

# default for ``QuantizationConfig.quant_method``; marks compressed-tensors configs
DEFAULT_QUANTIZATION_METHOD: str = "compressed-tensors"
# default for ``QuantizationConfig.format``
DEFAULT_QUANTIZATION_FORMAT: str = "fakequant"


class QuantizationConfig(BaseModel):
    """
    Full configuration specifying how a model is quantized. Each quantized layer is
    mapped to a QuantizationScheme in config_groups.

    :param config_groups: dict of QuantizationSchemes specifying the quantization
        settings for each quantized layer. A group could also be a reference to
        a predefined scheme name, mapped to a list of its target layers/classes
    :param quant_method: a constant used to differentiate compressed-tensors
        quantization from other quantization configs
    :param format: specifies how the quantized model is stored on disk
    :param quantization_status: specifies the current status of all quantized
        layers. It is assumed all layers are in the same state.
    :param kv_cache_scheme: optional QuantizationArgs, that specify the
        quantization of the kv cache. If None, kv cache is not quantized.
        When applying kv cache quantization to transformer AutoModelForCausalLM,
        the kv_cache_scheme gets converted into a QuantizationScheme that:
            - targets the `k_proj` and `v_proj` modules of the model. The outputs
              of those modules are the keys and values that might be cached
            - quantizes the outputs of the aforementioned layers, so that
              keys and values are compressed before storing them in the cache
        There is an explicit assumption that the model contains modules with
        `k_proj` and `v_proj` in their names. If this is not the case
        and kv_cache_scheme != None, the quantization of kv cache will fail
    :param global_compression_ratio: optional informational config to report the
        model compression ratio achieved by the quantization config
    :param ignore: optional list of layers to ignore from config_groups. Layers in
        this list are not quantized even if they match up with a target in
        config_groups
    :param run_compressed: unused; accepted only for backwards compatibility and
        excluded from serialization
    """

    config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
    quant_method: str = DEFAULT_QUANTIZATION_METHOD
    kv_cache_scheme: Optional[QuantizationArgs] = None
    format: str = DEFAULT_QUANTIZATION_FORMAT
    quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
    global_compression_ratio: Optional[float] = None
    ignore: Optional[List[str]] = Field(default_factory=list)
    # `run_compressed` is a dummy, unused arg for backwards compatibility
    # see: https://github.com/huggingface/transformers/pull/39324
    run_compressed: Annotated[Any, Field(exclude=True)] = None

    def model_post_init(self, __context):
        """
        updates any quantization schemes defined as presets to be fully loaded
        schemes
        """
        for group_name, targets_or_scheme in self.config_groups.items():
            if isinstance(targets_or_scheme, QuantizationScheme):
                continue  # scheme already defined
            # the group name is a preset scheme name and the value its targets;
            # overwriting an existing key does not resize the dict, so this is
            # safe while iterating
            self.config_groups[group_name] = preset_name_to_scheme(
                name=group_name,
                targets=targets_or_scheme,
            )

    def to_dict(self):
        """
        :return: the config serialized via ``model_dump`` (for compatibility with
            HFQuantizer)
        """
        return self.model_dump()

    @staticmethod
    def from_pretrained(
        model: Module, format: Optional[Union[str, list]] = None
    ) -> Optional["QuantizationConfig"]:
        """
        Converts a model into its associated QuantizationConfig based on the
        QuantizationScheme attached to each quantized module

        :param model: model to calculate quantization scheme of
        :param format: compression format of the model. A list with more than one
            entry resolves to the mixed-precision format and a single-entry list to
            that entry. If None (or an empty list), the format is inferred from the
            quantization status found on the model's modules
        :return: filled out QuantizationConfig for the input model, or None if the
            model has no quantized layers and no quantized kv cache
        """
        from compressed_tensors.modeling import IMPL_ATTR
        from compressed_tensors.quantization.lifecycle.initialize import (
            is_attention_module,
        )

        # unique quantization schemes, in order of first appearance
        # TODO: make quant config/scheme/args frozen/hashable and use a set
        quantization_schemes: List[QuantizationScheme] = []

        # use any status from modules (in practice, use the last module)
        # NOTE(review): if no quantized module exposes `quantization_status`, this
        # stays None, which the non-Optional `quantization_status` field may reject
        model_status = None

        # set of all quantized types
        # this is later used to create the ignore list
        quantization_type_names: Set[str] = set()

        # maps types to names which are not quantized
        # this is later used to create the ignore list
        ignore: Dict[str, List[str]] = defaultdict(list)

        # this keeps track of any kvcache schemes
        kv_cache_scheme: Optional[QuantizationArgs] = None

        for name, submodule in model.named_modules():
            layer_type: str = module_type(submodule)

            is_quantized = is_module_quantized(submodule)
            # short-circuit so attention detection only runs on quantized modules
            is_quantized_attention = is_quantized and is_attention_module(submodule)

            # add config group if quantized non-attention or attention quant
            has_config_group = is_quantized and (
                not is_quantized_attention or hasattr(submodule, IMPL_ATTR)
            )
            # only add kvcache if quant attention (which always implies kvcache)
            has_kv_cache = is_quantized_attention

            if has_config_group or has_kv_cache:
                model_status = getattr(submodule, "quantization_status", model_status)

            if has_config_group:
                # add to running set of schemes/layer_type_names
                quantization_type_names.add(layer_type)
                if submodule.quantization_scheme not in quantization_schemes:
                    quantization_schemes.append(submodule.quantization_scheme)
            else:
                # add non-quantized layers to the ignore list
                ignore[layer_type].append(name)

            if has_kv_cache:
                kv_cache_scheme = submodule.quantization_scheme.input_activations

        if not quantization_schemes and kv_cache_scheme is None:
            # No quantized layers
            return None

        # create ignore list, only include layers whose class has ever been
        # targeted; other classes don't fall under any of the existing
        # quantization schemes, so they won't be quantized and are left off
        consolidated_ignore = [
            layer_name
            for layer_type, layer_names in ignore.items()
            if layer_type in quantization_type_names
            for layer_name in layer_names
        ]

        # create config groups from all unique schemes
        config_groups = {
            f"group_{idx}": scheme for idx, scheme in enumerate(quantization_schemes)
        }

        # collapse a list of formats into a single format string
        if isinstance(format, list):
            if len(format) > 1:
                format = CompressionFormat.mixed_precision.value
            else:
                # an empty list carries no information; fall back to inference
                # below instead of raising IndexError
                format = format[0] if format else None

        # infer format from the quantization status
        if format is None:
            if model_status == QuantizationStatus.COMPRESSED:
                format = CompressionFormat.int_quantized.value
            else:
                format = CompressionFormat.dense.value

        return QuantizationConfig(
            config_groups=config_groups,
            quantization_status=model_status,
            kv_cache_scheme=kv_cache_scheme,
            global_compression_ratio=None,
            format=format,
            ignore=consolidated_ignore,
        )

    def requires_calibration_data(self):
        """
        :return: True if applying this config needs calibration data. That is the
            case when a kv cache scheme is set, when some group's input activations
            have ``dynamic`` equal to False or DynamicType.LOCAL, or when some
            group's output activations have a falsy ``dynamic``; False otherwise
        """
        if self.kv_cache_scheme is not None:
            return True

        for scheme in self.config_groups.values():
            if scheme.input_activations is not None:
                if scheme.input_activations.dynamic in (False, DynamicType.LOCAL):
                    return True
            if scheme.output_activations is not None:
                if not scheme.output_activations.dynamic:
                    return True

        return False

    # TODO set `extra="forbid"` when upstream transformers is compatible
    model_config = ConfigDict(extra="ignore")
