# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import wraps
from math import ceil
from typing import Optional

import torch
from compressed_tensors.quantization.quant_args import (
    DynamicType,
    QuantizationArgs,
    QuantizationStrategy,
    round_to_quantized_type_args,
)
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import (
    calculate_range,
    compute_dynamic_scales_and_zp,
)
from torch.nn import Module


# Public API of this module. The underscore-prefixed helpers defined below
# (`_process_quantization`, `_quantize`, `_dequantize`) are not exported.
__all__ = [
    "quantize",
    "dequantize",
    "fake_quantize",
    "wrap_module_forward_quantized",
    "forward_quantize",
]


@torch.no_grad()
def quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    args: QuantizationArgs,
    dtype: Optional[torch.dtype] = None,
    g_idx: Optional[torch.Tensor] = None,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Quantize the input tensor x using the QuantizationStrategy specified in args.
    Quantization can be done per tensor, channel, token, group or block. For group
    quantization, the column size (last dimension of x) must be divisible by the
    group_size. The input scale and zero_points are reshaped to support
    vectorization (Assumes 1 is the channel dimension)

    This is a thin wrapper around `_process_quantization` that runs only the
    quantize step (no dequantization).

    :param x: Input tensor
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :param args: quantization args dictating how to quantize x
    :param dtype: optional dtype to cast the quantized output to
    :param g_idx: optional mapping from column index to group index
    :param global_scale: optional constant; when provided, the local scale is
        divided by it before being applied
    :return: quantized tensor (not dequantized)
    """

    return _process_quantization(
        x=x,
        scale=scale,
        zero_point=zero_point,
        args=args,
        dtype=dtype,
        do_quantize=True,
        do_dequantize=False,
        g_idx=g_idx,
        global_scale=global_scale,
    )


@torch.no_grad()
def dequantize(
    x_q: torch.Tensor,
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor] = None,
    args: Optional[QuantizationArgs] = None,
    dtype: Optional[torch.dtype] = None,
    g_idx: Optional[torch.Tensor] = None,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Dequantize a quantized input tensor x_q based on the strategy specified in args.

    If args is not provided, the strategy is inferred from the shape of `scale`:

    * 0 or 1 dimensions -> tensor strategy
    * 2 dimensions with a single column -> channel strategy
    * 2 dimensions whose row count is 1 or matches the rows of x_q -> group
      strategy, with group_size = columns of x_q // columns of scale
    * any other 2 dimensional scale -> block strategy, with the block structure
      derived from the ratio between the shapes of x_q and scale

    :param x_q: quantized input tensor
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :param args: quantization args used to quantize x_q
    :param dtype: optional dtype to cast the dequantized output to; defaults to the
        dtype of scale
    :param g_idx: optional mapping from column index to group index
    :param global_scale: optional constant; when provided, the local scale is
        divided by it before being applied
    :return: dequantized float tensor
    :raises ValueError: if args is None and scale has more than 2 dimensions
    """
    if args is None:
        if scale.ndim in (0, 1):
            args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
        elif scale.ndim == 2:
            if scale.shape[1] == 1:
                args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
            # Scale height matches input or is 1 -> group quantization across columns
            #
            # Example 1: scale.shape[0] == 1
            # x_q: (4, 8), scale: (1, 4) -> 2 columns per group
            #
            # Example 2: scale.shape[0] == x_q.shape[-2]
            # x_q: (4, 8), scale: (4, 4) -> 2 elements per group (per row)
            #
            # Rows/columns are read from the trailing dimensions of x_q, matching
            # how the group and block strategies index the tensor when applied.
            elif (scale.shape[0] == 1) or (scale.shape[0] == x_q.shape[-2]):
                # integer division avoids a float round trip
                group_size = x_q.shape[-1] // scale.shape[1]
                args = QuantizationArgs(
                    strategy=QuantizationStrategy.GROUP, group_size=group_size
                )
            else:
                rows, cols = x_q.shape[-2], x_q.shape[-1]
                block_height = rows // scale.shape[0]  # Rows per block
                block_width = cols // scale.shape[1]  # Columns per block

                args = QuantizationArgs(
                    strategy=QuantizationStrategy.BLOCK,
                    block_structure=[block_height, block_width],
                )
        else:
            raise ValueError(
                f"Could not infer a quantization strategy from scale with {scale.ndim} "
                "dimensions. Expected 0, 1 or 2 dimensions."
            )

    if dtype is None:
        # default to the precision of the scale
        dtype = scale.dtype

    return _process_quantization(
        x=x_q,
        scale=scale,
        zero_point=zero_point,
        args=args,
        do_quantize=False,
        do_dequantize=True,
        dtype=dtype,
        g_idx=g_idx,
        global_scale=global_scale,
    )


@torch.no_grad()
def fake_quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    args: QuantizationArgs,
    g_idx: Optional[torch.Tensor] = None,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Fake quantize the input tensor x by quantizing then dequantizing with
    the QuantizationStrategy specified in args. Quantization can be done per tensor,
    channel, token, group or block. For group quantization, the column size (last
    dimension of x) must be divisible by the group_size. The input scale and
    zero_points are reshaped to support vectorization (Assumes 1 is the channel
    dimension)

    This is a thin wrapper around `_process_quantization` that runs both the
    quantize and the dequantize step; no output dtype is passed.

    :param x: Input tensor
    :param scale: scale tensor
    :param zero_point: zero point tensor
    :param args: quantization args dictating how to quantize x
    :param g_idx: optional mapping from column index to group index
    :param global_scale: optional constant; when provided, the local scale is
        divided by it before being applied
    :return: fake quantized tensor
    """
    return _process_quantization(
        x=x,
        scale=scale,
        zero_point=zero_point,
        args=args,
        do_quantize=True,
        do_dequantize=True,
        g_idx=g_idx,
        global_scale=global_scale,
    )


@torch.no_grad()
def _process_quantization(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    args: QuantizationArgs,
    g_idx: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    do_quantize: bool = True,
    do_dequantize: bool = True,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Shared implementation behind `quantize`, `dequantize` and `fake_quantize`.

    Reshapes x, scale and zero_point according to args.strategy so the elementwise
    `_quantize` / `_dequantize` helpers can broadcast over them, then restores
    the layout of x.

    :param x: tensor to process (float input when quantizing, quantized input when
        only dequantizing)
    :param scale: scale tensor
    :param zero_point: zero point tensor, may be None
    :param args: quantization args selecting the strategy, range and group/block
        sizes
    :param g_idx: optional mapping from column index to group index; only used by
        the group strategies. A g_idx containing -1 is treated as uninitialized
    :param dtype: optional dtype for the quantized values; the group strategies
        additionally cast their final output to it
    :param do_quantize: run the quantize step
    :param do_dequantize: run the dequantize step (after quantize, if both are set)
    :param global_scale: optional constant forwarded to the quantize/dequantize
        helpers
    :return: processed tensor. If neither step is requested, x is passed through
        (the group strategies still apply their output dtype cast)
    :raises ValueError: if the tensor dimensions are not divisible by the block
        structure, or the column count is not divisible by group_size
    """
    q_min, q_max = calculate_range(args, x.device)
    group_size = args.group_size

    # blockwise FP8: quantize per 2D block, supports block_structure for static block
    # quantization
    if args.strategy == QuantizationStrategy.BLOCK:
        original_shape = x.shape
        rows, cols = x.shape[-2], x.shape[-1]
        block_height, block_width = args.block_structure

        # Ensure exact division (tensor dimensions must be divisible by block size)
        if rows % block_height != 0:
            raise ValueError(
                f"Tensor height {rows} is not divisible by block_height {block_height}."
                f" Block quantization requires exact division."
            )
        if cols % block_width != 0:
            raise ValueError(
                f"Tensor width {cols} is not divisible by block_width {block_width}. "
                f"Block quantization requires exact division."
            )

        # reshape into blocks and transpose to make each block contiguous
        num_rows_blocks = rows // block_height
        num_cols_blocks = cols // block_width
        x_blocks = x.reshape(
            num_rows_blocks,
            block_height,
            num_cols_blocks,
            block_width,
        ).transpose(1, 2)

        # expand scale/zero_point so one entry broadcasts over one whole block
        sb = scale.unsqueeze(-1).unsqueeze(-1)
        zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None
        if do_quantize:
            x_blocks = _quantize(
                x=x_blocks,
                scale=sb,
                zero_point=zb,
                q_min=q_min,
                q_max=q_max,
                args=args,
                dtype=dtype,
                global_scale=global_scale,
            )
        if do_dequantize:
            x_blocks = _dequantize(
                x_q=x_blocks,
                scale=sb,
                zero_point=zb,
                global_scale=global_scale,
            )
        # restore original shape
        output = x_blocks.transpose(1, 2).reshape(original_shape)
    elif args.strategy in (
        QuantizationStrategy.GROUP,
        QuantizationStrategy.TENSOR_GROUP,
    ):
        output_dtype = dtype if dtype is not None else x.dtype
        columns = x.shape[-1]

        # TODO: make validation step for inputs

        while scale.ndim < 2:
            # pad scale and zero point dims for slicing
            scale = scale.unsqueeze(1)
            zero_point = zero_point.unsqueeze(1) if zero_point is not None else None

        # NOTE(review): when columns < group_size no check runs here; the
        # unflatten below is then expected to reject the shape - confirm intended.
        if columns >= group_size and columns % group_size != 0:
            raise ValueError(
                "tensor column shape must be divisble "
                f"by the given group_size {group_size} but got {columns}"
            )

        # support column-order (default) quantization as well as other orderings
        # such as activation ordering. Below checks if g_idx has been initialized
        is_column_order = g_idx is None or -1 in g_idx
        if not is_column_order:
            # sort columns so that members of the same group become adjacent
            perm = torch.argsort(g_idx)
            x = x.index_select(-1, perm)

        # Maintain all dimensions except the last dim, which is divided by group_size
        x = x.unflatten(-1, (ceil(x.shape[-1] / group_size), group_size))

        # trailing singleton dim lets each group's scale broadcast over its members
        group_scale = scale.unsqueeze(-1)
        group_zero_point = zero_point.unsqueeze(-1) if zero_point is not None else None

        output = x
        if do_quantize:
            output = _quantize(
                x=output,
                scale=group_scale,
                zero_point=group_zero_point,
                dtype=dtype,
                global_scale=global_scale,
                q_min=q_min,
                q_max=q_max,
                args=args,
            )

        if do_dequantize:
            output = _dequantize(
                x_q=output,
                scale=group_scale,
                zero_point=group_zero_point,
                global_scale=global_scale,
            )

        # merge (num_groups, group_size) back into a single column dimension
        output = output.flatten(start_dim=-2)
        output = output.to(output_dtype)

        if not is_column_order:
            # undo the column permutation applied above
            inv_perm = torch.argsort(perm)
            output = output.index_select(-1, inv_perm)

    else:  # covers tensor, channel, token, and attn_head strategies
        output = x
        if do_quantize:
            output = _quantize(
                x=output,
                scale=scale,
                zero_point=zero_point,
                q_min=q_min,
                q_max=q_max,
                args=args,
                dtype=dtype,
                global_scale=global_scale,
            )
        if do_dequantize:
            output = _dequantize(
                output,
                scale=scale,
                zero_point=zero_point,
                global_scale=global_scale,
            )

    return output


def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
    """
    Replace `module.forward` with a wrapper that applies (fake) quantization to the
    input activations, weight and output activations as configured by `scheme`.

    Expects a module already initialized and injected with the parameters in
    initialize_module_for_quantization.

    Behavior of the installed wrapper:

    * if `module.quantization_enabled` is False, the original forward is called
      untouched
    * the first positional argument is treated as the input activation
    * unless the module status is COMPRESSED, the weight is temporarily replaced by
      its fake quantized version for the duration of the forward call and is
      restored afterwards, even if the forward call raises
    * while the module status is CALIBRATION and output activations are not
      dynamic, the output is returned without quantization

    :param module: module whose forward method is wrapped in place
    :param scheme: quantization scheme describing which tensors to quantize
    """
    if hasattr(module.forward, "__func__"):
        # regular bound method: unwrap to the underlying function
        forward_func_orig = module.forward.__func__
    else:
        # otherwise expects an object exposing the function as `.func`
        forward_func_orig = module.forward.func

    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
    def wrapped_forward(self, *args, **kwargs):
        if not getattr(module, "quantization_enabled", True):
            # quantization is disabled on forward passes, return baseline
            # forward call
            return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)

        input_ = args[0]

        compressed = module.quantization_status == QuantizationStatus.COMPRESSED
        quantize_weight = scheme.weights is not None and not compressed

        if scheme.input_activations is not None:
            # prehook should calibrate activations before forward call
            input_ = forward_quantize(module, input_, "input", scheme.input_activations)

        if quantize_weight:
            # calibrate and (fake) quantize weights when applicable
            unquantized_weight = self.weight.data.clone()
            self.weight.data = forward_quantize(
                module, self.weight, "weight", scheme.weights
            )

        try:
            # perform wrapped forward call
            output = forward_func_orig.__get__(module, module.__class__)(
                input_, *args[1:], **kwargs
            )
        finally:
            # restore the unquantized weight even if the forward call raised, so
            # a failed forward pass cannot leave the module with quantized weights
            if quantize_weight:
                self.weight.data = unquantized_weight

        if scheme.output_activations is not None:
            # forward-hook should calibrate/forward_quantize
            if (
                module.quantization_status == QuantizationStatus.CALIBRATION
                and not scheme.output_activations.dynamic
            ):
                return output

            output = forward_quantize(
                module, output, "output", scheme.output_activations
            )
        return output

    # bind wrapped forward to module class so reference to `self` is correct
    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
    # set forward to wrapped forward
    setattr(module, "forward", bound_wrapped_forward)


def forward_quantize(
    module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
) -> torch.Tensor:
    """
    Fake quantize `value` using quantization parameters attached to `module`.

    :param module: module holding the quantization status and, for static
        quantization, the `<base_name>_scale` / `<base_name>_zero_point` attributes
    :param value: tensor to fake quantize
    :param base_name: prefix of the quantization attributes on the module; this
        file uses "input", "weight" and "output"
    :param args: quantization args dictating how to quantize value
    :return: fake quantized tensor, or `value` itself when it is empty or when it
        is the weight of a module whose status is COMPRESSED
    """
    is_compressed = module.quantization_status == QuantizationStatus.COMPRESSED

    # A compressed weight is already quantized, and an empty tensor has nothing to
    # quantize: both are returned untouched.
    if (is_compressed and base_name == "weight") or value.numel() == 0:
        return value

    # NOTE(review): g_idx is always read from `weight_g_idx`, whatever base_name
    # is - confirm this is intended for activations.
    column_groups = getattr(module, "weight_g_idx", None)
    global_scale = getattr(module, f"{base_name}_global_scale", None)

    uses_dynamic_params = args.dynamic in (True, DynamicType.LOCAL)
    if uses_dynamic_params:
        # dynamic quantization: scale / zero point are derived from the value
        scale, zero_point = compute_dynamic_scales_and_zp(
            value=value, args=args, module=module, global_scale=global_scale
        )
    else:
        # static quantization: scale / zero point are stored on the layer
        scale = getattr(module, f"{base_name}_scale")
        zero_point = getattr(module, f"{base_name}_zero_point", None)

    return fake_quantize(
        x=value,
        scale=scale,
        zero_point=zero_point,
        args=args,
        g_idx=column_groups,
        global_scale=global_scale,
    )


@torch.no_grad()
def _quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    q_min: torch.Tensor,
    q_max: torch.Tensor,
    args: QuantizationArgs,
    dtype: Optional[torch.dtype] = None,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Elementwise quantization: divide by the scale, shift by the zero point, then
    clamp and round to the quantized type described by args.

    :param x: tensor to quantize
    :param scale: scale tensor, broadcastable against x
    :param zero_point: zero point tensor, or None to skip the shift
    :param q_min: lower bound of the quantized range
    :param q_max: upper bound of the quantized range
    :param args: quantization args forwarded to the rounding helper
    :param dtype: optional dtype to cast the result to
    :param global_scale: optional constant; when provided, scale is divided by it
    :return: quantized tensor
    """
    # an optional global scale further rescales the local scale
    effective_scale = scale if global_scale is None else scale / global_scale

    shifted = x / effective_scale
    if zero_point is not None:
        # in-place add keeps the dtype produced by the division
        shifted += zero_point.to(x.dtype)

    # clamp and round
    rounded = round_to_quantized_type_args(
        tensor=shifted, args=args, min=q_min, max=q_max
    )

    return rounded if dtype is None else rounded.to(dtype)


@torch.no_grad()
def _dequantize(
    x_q: torch.Tensor,
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Elementwise dequantization: subtract the zero point, then multiply by the
    scale. The arithmetic is carried out in the dtype of the (rescaled) scale.

    :param x_q: quantized tensor
    :param scale: scale tensor, broadcastable against x_q
    :param zero_point: zero point tensor, or None to skip the shift
    :param dtype: optional dtype to cast the result to
    :param global_scale: optional constant; when provided, scale is divided by it
    :return: dequantized tensor
    """
    # an optional global scale further rescales the local scale
    effective_scale = scale if global_scale is None else scale / global_scale
    compute_dtype = effective_scale.dtype

    restored = x_q.to(compute_dtype)
    if zero_point is not None:
        restored = restored - zero_point.to(compute_dtype)
    restored = restored * effective_scale

    return restored if dtype is None else restored.to(dtype)
