# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
from typing import Optional, Tuple, Union

import torch
from compressed_tensors.modeling import (
    IMPL_ATTR,
    KV_CACHE_ATTR,
    QuantizedAttentionImpl,
    QuantizedKVCache,
)
from compressed_tensors.quantization import (
    ActivationOrdering,
    DynamicType,
    QuantizationArgs,
    QuantizationMetadata,
    QuantizationScheme,
    QuantizationStatus,
    QuantizationStrategy,
)
from compressed_tensors.quantization.lifecycle.forward import (
    wrap_module_forward_quantized,
)
from compressed_tensors.quantization.utils import strategy_cdiv
from compressed_tensors.utils import (
    disable_hf_hook,
    get_execution_device,
    get_head_dim,
    get_num_attn_heads,
    get_num_kv_heads,
    register_offload_parameter,
)
from torch.nn import Module, Parameter


__all__ = [
    "initialize_module_for_quantization",
    "is_attention_module",
    "initialize_qparams",
    "initialize_attn_qparams",
]


_LOGGER = logging.getLogger(__name__)


def initialize_module_for_quantization(
    module: Module,
    scheme: Optional[QuantizationScheme] = None,
    force_zero_point: bool = True,
):
    """
    Attaches appropriate scales, zero points, and observers to a layer
    given its target quantization scheme.

    Previously initialized scales and zero points will be removed from
    module if they no longer apply to the scheme

    :param module: module to set for calibration
    :param scheme: scheme to use for quantization. if None is provided,
        will attempt to use scheme stored in the module under `quantization_scheme`,
        if not provided, the layer will be skipped
    :param force_zero_point: whether to force initialization of a zero point for
        symmetric quantization
    """
    scheme = scheme or getattr(module, "quantization_scheme", None)
    if scheme is None:
        return

    QuantizationMetadata.clear_all_qparams(module)

    if is_attention_module(module):
        # quantized actions based on calltime status
        initialize_attn_qparams(module, scheme, force_zero_point)

    else:
        if not isinstance(module, torch.nn.Linear):
            _LOGGER.warning(f"Attempting to quantize module of type {type(module)}")

        # use weight to determine observed shapes and dtype
        if hasattr(module, "weight"):
            weight = module.weight
            assert isinstance(weight, torch.Tensor)
        else:
            # Note that a weight is required for both weight and activation
            # quantization in order to know the dtype of activation scales
            _LOGGER.warning(
                f"module type {type(module)} targeted for quantization but "
                f"has no attribute weight, skipping quantization for {type(module)}"
            )
            return

        if scheme.input_activations is not None:
            initialize_qparams(
                module,
                "input",
                scheme.input_activations,
                observed_shape=weight.shape[-1:],
                observed_dtype=weight.dtype,
                force_zero_point=force_zero_point,
            )

        if scheme.weights is not None:
            initialize_qparams(
                module,
                "weight",
                scheme.weights,
                observed_shape=weight.shape,
                observed_dtype=weight.dtype,
                force_zero_point=force_zero_point,
            )

        if scheme.output_activations is not None:
            initialize_qparams(
                module,
                "output",
                scheme.output_activations,
                observed_shape=weight.shape[:-1],
                observed_dtype=weight.dtype,
                force_zero_point=force_zero_point,
            )

        with disable_hf_hook(module):
            # wrap forward call of module to perform
            # quantized actions based on calltime status
            wrap_module_forward_quantized(module, scheme)

    module.quantization_scheme = scheme
    module.quantization_status = QuantizationStatus.INITIALIZED


def is_attention_module(module: Module):
    return "attention" in module.__class__.__name__.lower() and (
        hasattr(module, "k_proj")
        or hasattr(module, "v_proj")
        or hasattr(module, "qkv_proj")
    )


def initialize_qparams(
    module: Module,
    base_name: str,
    quantization_args: QuantizationArgs,
    observed_shape: Tuple[Union[int, None]],
    observed_dtype: torch.dtype,
    force_zero_point: bool = True,
):
    """
    Initialize quantization parameters for a given basename according to the passed
    quantization args. The shape and dtype of the observed weight/activation must also
    be provided.

    Scales will always be initialized. Global scales are initialized depending on args.
    Zero points will be initialized if not symmetric or if `force_zero_point` is True.

    :param module: module to register qparams to
    :param base_name: base name of qparams, for example "input", "weight", "k", "v"
    :param quantization_args: arguments for quantization
    :param observed_shape: last (right-most) known dimensions of the observed weight/act
    :param observed_dtype: dtype of the observed weight/actt
    :param force_zero_point: force the zero_point parameter to be initialized
    """
    strategy = quantization_args.strategy
    dynamic = quantization_args.dynamic
    actorder = quantization_args.actorder
    device = get_execution_device(module)  # avoid performing intialization ops on cpu

    # Skip all intialization for fully dynamic quantization
    if dynamic is True:
        return

    # 0. Create global scale for tensor-group quantization
    if strategy == QuantizationStrategy.TENSOR_GROUP:
        init_global_scale = Parameter(
            torch.empty(1, dtype=torch.float32, device=device),
            requires_grad=False,
        )
        register_offload_parameter(
            module, f"{base_name}_global_scale", init_global_scale
        )

    # Skip scale/zp initialization for locally dynamic quantization
    if dynamic == DynamicType.LOCAL:
        return

    # 1. Infer expected scale/zp shape
    if strategy == QuantizationStrategy.TENSOR:
        expected_shape = (1,)

    elif strategy == QuantizationStrategy.TOKEN:
        raise ValueError("Cannot perform static token quantization")

    elif strategy == QuantizationStrategy.CHANNEL:
        if len(observed_shape) < 2:
            raise ValueError("Channel quant requires at least 2 observed dimensions")

        expected_shape = (observed_shape[-2], 1)

    elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
        assert quantization_args.group_size is not None
        if len(observed_shape) < 1:
            raise ValueError("Group quant requires at least 1 observed dimension")

        group_size = quantization_args.group_size
        num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy)
        expected_shape = (*observed_shape[:-1], num_groups)

        # initialize activation ordering if applicable
        if actorder == ActivationOrdering.GROUP:
            init_g_idx = Parameter(
                torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int),
                requires_grad=False,
            )
            register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)

    elif strategy == QuantizationStrategy.BLOCK:
        assert quantization_args.block_structure is not None
        if len(observed_shape) < 2:
            raise ValueError("Block quant requires at least 2 observed dimensions")

        block_structure = quantization_args.block_structure
        num_rows = strategy_cdiv(observed_shape[-2], block_structure[-2], strategy)
        num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy)
        expected_shape = (num_rows, num_cols)

    elif strategy == QuantizationStrategy.ATTN_HEAD:
        # (batch_size, num_attention_heads, seq_len, head_dim)
        if len(observed_shape) < 3:
            raise ValueError("Attention quant requires at least 3 observed dimensions")

        expected_shape = (observed_shape[-3], 1, 1)

    else:
        assert False, f"Unknown strategy {strategy}"

    # 2. Identify quantization scale and zp dtype
    scale_dtype = observed_dtype
    if scale_dtype not in [
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    ]:
        scale_dtype = torch.float16

    # 3. Initializes scale/zp for the module
    init_scale = Parameter(
        torch.empty(expected_shape, dtype=scale_dtype, device=device),
        requires_grad=False,
    )
    register_offload_parameter(module, f"{base_name}_scale", init_scale)

    if force_zero_point or not quantization_args.symmetric:
        init_zero_point = Parameter(
            torch.zeros(
                expected_shape, device=device, dtype=quantization_args.zp_dtype
            ),
            requires_grad=False,
        )
        register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)


def initialize_attn_qparams(
    module: Module, scheme: QuantizationScheme, force_zero_point: bool
):
    """Initlaize k_scale, v_scale for self_attn"""

    impl: Optional[QuantizedAttentionImpl] = getattr(module, IMPL_ATTR, None)
    kv_cache: Optional[QuantizedKVCache] = getattr(module, KV_CACHE_ATTR, None)

    if impl is None and kv_cache is None:
        raise ValueError(
            f"Attention module has quantization scheme but no {IMPL_ATTR} "
            f"or {KV_CACHE_ATTR} attributes. Please ensure that these "
            "attributes are initialized using `apply_quantization_config`."
        )

    _validate_attention_scheme(scheme)

    # extract shapes from config
    config = kv_cache.config
    num_attn_heads = get_num_attn_heads(config)
    num_kv_heads = get_num_kv_heads(config)
    head_dim = get_head_dim(config)

    # (batch_size, num_heads, slen, head_dim)
    q_observed_shape = (num_attn_heads, None, head_dim)
    kv_observed_shape = (num_kv_heads, None, head_dim)
    observed_dtype = next(module.parameters()).dtype

    if impl is not None:
        initialize_qparams(
            module,
            "q",
            scheme.input_activations,
            observed_shape=q_observed_shape,
            observed_dtype=observed_dtype,
            force_zero_point=force_zero_point,
        )

    if kv_cache is not None:
        initialize_qparams(
            module,
            "k",
            scheme.input_activations,
            observed_shape=kv_observed_shape,
            observed_dtype=observed_dtype,
            force_zero_point=force_zero_point,
        )
        initialize_qparams(
            module,
            "v",
            scheme.input_activations,
            observed_shape=kv_observed_shape,
            observed_dtype=observed_dtype,
            force_zero_point=force_zero_point,
        )


def _validate_attention_scheme(scheme: QuantizationScheme):
    if scheme.weights is not None:
        raise ValueError(
            "Cannot apply weight quantization to attention. "
            "Instead, target the (q|k|v)_proj submodule layers of attention"
        )

    if scheme.input_activations is None:
        raise ValueError(
            "Cannot apply attention quantization without specifying input activations"
        )

    if scheme.output_activations is not None:
        raise ValueError("Cannot apply output quantization to attention")
